diff --git a/coap/api/transport.go b/coap/api/transport.go
index 64d8fabb44..3f83dca91d 100644
--- a/coap/api/transport.go
+++ b/coap/api/transport.go
@@ -4,6 +4,7 @@ package api
 
 import (
+	"bytes"
 	"context"
 	"fmt"
 	"io"
@@ -66,8 +67,15 @@ func MakeCoAPHandler(svc coap.Service, l *slog.Logger) mux.HandlerFunc {
 	return handler
 }
 
-func sendResp(w mux.ResponseWriter, resp *message.Message) {
-	if err := w.Client().WriteMessage(resp); err != nil {
+func sendResp(ctx context.Context, w mux.ResponseWriter, resp *message.Message) {
+	m := w.Conn().AcquireMessage(ctx)
+	m.SetCode(resp.Code)
+	m.SetBody(bytes.NewReader(resp.Payload))
+	m.SetToken(resp.Token)
+	for _, option := range resp.Options {
+		m.SetOptionBytes(option.ID, option.Value)
+	}
+	if err := w.Conn().WriteMessage(m); err != nil {
 		logger.Warn(fmt.Sprintf("Can't set response: %s", err))
 	}
 }
@@ -75,11 +83,10 @@ func sendResp(w mux.ResponseWriter, resp *message.Message) {
 func handler(w mux.ResponseWriter, m *mux.Message) {
 	resp := message.Message{
 		Code:    codes.Content,
-		Token:   m.Token,
-		Context: m.Context,
+		Token:   m.Token(),
 		Options: make(message.Options, 0, 16),
 	}
-	defer sendResp(w, &resp)
+	defer sendResp(m.Context(), w, &resp)
 	msg, err := decodeMessage(m)
 	if err != nil {
 		logger.Warn(fmt.Sprintf("Error decoding message: %s", err))
@@ -92,12 +99,12 @@ func handler(w mux.ResponseWriter, m *mux.Message) {
 		resp.Code = codes.Unauthorized
 		return
 	}
-	switch m.Code {
+	switch m.Code() {
 	case codes.GET:
-		err = handleGet(m.Context, m, w.Client(), msg, key)
+		err = handleGet(m.Context(), m, w.Conn(), msg, key)
 	case codes.POST:
 		resp.Code = codes.Created
-		err = service.Publish(m.Context, key, msg)
+		err = service.Publish(m.Context(), key, msg)
 	default:
 		err = svcerr.ErrNotFound
 	}
@@ -116,9 +123,9 @@ func handler(w mux.ResponseWriter, m *mux.Message) {
 	}
 }
 
-func handleGet(ctx context.Context, m *mux.Message, c mux.Client, msg *messaging.Message, key string) error {
+func handleGet(ctx context.Context, m *mux.Message, c mux.Conn, msg *messaging.Message, key string) error {
 	var obs uint32
-	obs, err := m.Options.Observe()
+	obs, err := m.Observe()
 	if err != nil {
 		logger.Warn(fmt.Sprintf("Error reading observe option: %s", err))
 		return errBadOptions
 	}
@@ -131,10 +138,10 @@
 }
 
 func decodeMessage(msg *mux.Message) (*messaging.Message, error) {
-	if msg.Options == nil {
+	if msg.Options() == nil {
 		return &messaging.Message{}, errBadOptions
 	}
-	path, err := msg.Options.Path()
+	path, err := msg.Path()
 	if err != nil {
 		return &messaging.Message{}, err
 	}
@@ -156,7 +163,7 @@ func decodeMessage(msg *mux.Message) (*messaging.Message, error) {
 	}
 
-	if msg.Body != nil {
-		buff, err := io.ReadAll(msg.Body)
+	if msg.Body() != nil {
+		buff, err := io.ReadAll(msg.Body())
 		if err != nil {
 			return ret, err
 		}
@@ -166,10 +173,10 @@ func decodeMessage(msg *mux.Message) (*messaging.Message, error) {
 }
 
 func parseKey(msg *mux.Message) (string, error) {
-	if obs, _ := msg.Options.Observe(); obs != 0 && msg.Code == codes.GET {
+	if obs, _ := msg.Observe(); obs != 0 && msg.Code() == codes.GET {
 		return "", nil
 	}
-	authKey, err := msg.Options.GetString(message.URIQuery)
+	authKey, err := msg.Options().GetString(message.URIQuery)
 	if err != nil {
 		return "", err
 	}
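Side note on the mux API migration above: responses are no longer built as plain `message.Message` values; they are acquired from the connection's message pool and populated through setters. A minimal sketch of the pattern, assuming the go-coap v3-style `mux` types this diff moves to (the handler name, payload, and v3 import paths are assumptions, not part of this PR):

```go
import (
	"bytes"
	"log"

	"github.com/plgd-dev/go-coap/v3/message/codes"
	"github.com/plgd-dev/go-coap/v3/mux"
)

// pingHandler is a hypothetical handler showing the pooled-message
// response pattern that sendResp uses in this diff. Register it with
// something like: r := mux.NewRouter(); r.Handle("/ping", mux.HandlerFunc(pingHandler)).
func pingHandler(w mux.ResponseWriter, r *mux.Message) {
	resp := w.Conn().AcquireMessage(r.Context()) // borrow a message from the conn's pool
	resp.SetCode(codes.Content)
	resp.SetToken(r.Token())
	resp.SetBody(bytes.NewReader([]byte("pong")))
	if err := w.Conn().WriteMessage(resp); err != nil {
		log.Printf("cannot write response: %s", err)
	}
}
```

diff --git a/coap/client.go b/coap/client.go
index a71d896e56..9e520acb46 100644
--- a/coap/client.go
+++ b/coap/client.go
@@ -36,7 +36,7 @@ type Client interface {
 
 var ErrOption = errors.New("unable to set option")
 
 type client struct {
-	client  mux.Client
+	client  mux.Conn
 	token   message.Token
 	observe uint32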
	logger  *slog.Logger
@@ -57,13 +57,10 @@ func (c *client) Done() <-chan struct{} {
 }
 
 func (c *client) Cancel() error {
-	m := message.Message{
-		Code:    codes.Content,
-		Token:   c.token,
-		Context: context.Background(),
-		Options: make(message.Options, 0, 16),
-	}
-	if err := c.client.WriteMessage(&m); err != nil {
+	pm := c.client.AcquireMessage(context.Background())
+	pm.SetCode(codes.Content)
+	pm.SetToken(c.token)
+	if err := c.client.WriteMessage(pm); err != nil {
 		c.logger.Error(fmt.Sprintf("Error sending message: %s.", err))
 	}
 	return c.client.Close()
@@ -74,12 +71,10 @@ func (c *client) Token() string {
 }
 
 func (c *client) Handle(msg *messaging.Message) error {
-	m := message.Message{
-		Code:    codes.Content,
-		Token:   c.token,
-		Context: c.client.Context(),
-		Body:    bytes.NewReader(msg.GetPayload()),
-	}
+	pm := c.client.AcquireMessage(context.Background())
+	pm.SetCode(codes.Content)
+	pm.SetToken(c.token)
+	pm.SetBody(bytes.NewReader(msg.GetPayload()))
 
 	atomic.AddUint32(&c.observe, 1)
 	var opts message.Options
@@ -103,6 +98,8 @@ func (c *client) Handle(msg *messaging.Message) error {
 		return fmt.Errorf("cannot set options to response: %w", err)
 	}
 
-	m.Options = opts
-	return c.client.WriteMessage(&m)
+	for _, option := range opts {
+		pm.SetOptionBytes(option.ID, option.Value)
+	}
+	return c.client.WriteMessage(pm)
 }
diff --git a/docker/nginx/nginx-key.conf b/docker/nginx/nginx-key.conf
index f561229a9b..4dfb27f9be 100644
--- a/docker/nginx/nginx-key.conf
+++ b/docker/nginx/nginx-key.conf
@@ -188,13 +188,13 @@ http {
     }
 }
 
-# MQTT
 stream {
     include snippets/stream_access_log.conf;
 
+    # MQTT
     # Include single-node or multiple-node (cluster) upstream
     include snippets/mqtt-upstream.conf;
-    
+
     server {
         listen ${MG_NGINX_MQTT_PORT};
         listen [::]:${MG_NGINX_MQTT_PORT};
@@ -205,6 +205,20 @@ stream {
 
         proxy_pass mqtt_cluster;
     }
+
+    # CoAP
+    include snippets/coap-upstream.conf;
+
+    server {
+        listen ${MG_NGINX_COAP_PORT};
+        listen [::]:${MG_NGINX_COAP_PORT};
+        listen ${MG_NGINX_COAPS_PORT} ssl;
+        listen [::]:${MG_NGINX_COAPS_PORT} ssl;
+
+        include snippets/ssl.conf;
+
+        proxy_pass coap_cluster;
+    }
 }
 
 error_log info.log info;
diff --git a/docker/nginx/nginx-x509.conf b/docker/nginx/nginx-x509.conf
index 49f0170e60..e5ce4c5857 100644
--- a/docker/nginx/nginx-x509.conf
+++ b/docker/nginx/nginx-x509.conf
@@ -210,11 +210,13 @@ stream {
 
     js_import authorization from /etc/nginx/authorization.js;
 
-    # Include single-node or multiple-node (cluster) upstream
-    include snippets/mqtt-upstream.conf;
     ssl_verify_client on;
     include snippets/ssl-client.conf;
 
+    # MQTT
+    # Include single-node or multiple-node (cluster) upstream
+    include snippets/mqtt-upstream.conf;
+
     server {
         listen ${MG_NGINX_MQTT_PORT};
         listen [::]:${MG_NGINX_MQTT_PORT};
@@ -226,6 +228,21 @@ stream {
 
         proxy_pass mqtt_cluster;
     }
+
+    # CoAP
+    include snippets/coap-upstream.conf;
+
+    server {
+        listen ${MG_NGINX_COAP_PORT};
+        listen [::]:${MG_NGINX_COAP_PORT};
+        listen ${MG_NGINX_COAPS_PORT} ssl;
+        listen [::]:${MG_NGINX_COAPS_PORT} ssl;
+
+        include snippets/ssl.conf;
+        js_preread authorization.authenticate;
+
+        proxy_pass coap_cluster;
+    }
 }
 
 error_log info.log info;
diff --git a/docker/nginx/snippets/coap-upstream.conf b/docker/nginx/snippets/coap-upstream.conf
new file mode 100644
index 0000000000..da6c4b87e0
--- /dev/null
+++ b/docker/nginx/snippets/coap-upstream.conf
@@ -0,0 +1,6 @@
+# Copyright (c) Mainflux
+# SPDX-License-Identifier: Apache-2.0
+
+upstream coap_cluster {
+    server coap-adapter:5683;
+}
diff --git a/go.mod b/go.mod
index def75f38b1..54f55964e8
100644 --- a/go.mod +++ b/go.mod @@ -148,7 +148,7 @@ require ( github.com/pelletier/go-toml/v2 v2.2.0 // indirect github.com/pion/dtls/v2 v2.2.10 // indirect github.com/pion/logging v0.2.2 // indirect - github.com/pion/transport/v2 v2.2.4 // indirect + github.com/pion/transport/v3 v3.0.1 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/plgd-dev/kit/v2 v2.0.0-20211006190727-057b33161b90 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect diff --git a/go.sum b/go.sum index 0e95bfef17..905c68e56b 100644 --- a/go.sum +++ b/go.sum @@ -85,8 +85,6 @@ github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= -github.com/dsnet/golib/memfile v0.0.0-20190531212259-571cdbcff553/go.mod h1:tXGNW9q3RwvWt1VV2qrRKlSSz0npnh12yftCSCy2T64= -github.com/dsnet/golib/memfile v0.0.0-20200723050859-c110804dfa93/go.mod h1:tXGNW9q3RwvWt1VV2qrRKlSSz0npnh12yftCSCy2T64= github.com/dsnet/golib/memfile v1.0.0 h1:J9pUspY2bDCbF9o+YGwcf3uG6MdyITfh/Fk3/CaEiFs= github.com/dsnet/golib/memfile v1.0.0/go.mod h1:tXGNW9q3RwvWt1VV2qrRKlSSz0npnh12yftCSCy2T64= github.com/eclipse/paho.mqtt.golang v1.4.3 h1:2kwcUGn8seMUfWndX0hGbvH8r7crgcJguQNCyp70xik= @@ -128,7 +126,6 @@ github.com/go-kit/kit v0.13.0/go.mod h1:phqEHMMUbyrCFCTgH48JueqrM3md2HcAZ8N3XE4F github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= github.com/go-kit/log v0.2.1 h1:MRVx0/zhvdseW+Gza6N9rVzU/IVzaeE1SFI4raAhmBU= github.com/go-kit/log v0.2.1/go.mod h1:NwTd00d/i8cPZ3xOwwiv2PO5MOcx78fFErGNcVmBjv0= -github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4= github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs= @@ -151,12 +148,9 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/gocql/gocql v1.6.0 h1:IdFdOTbnpbd0pDhl4REKQDM+Q0SzKXQ1Yh+YZZ8T/qU= github.com/gocql/gocql v1.6.0/go.mod h1:3gM2c4D3AnkISwBxGnMMsS8Oy4y2lhbPRsH4xnJrHG8= -github.com/gofrs/uuid v3.3.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gofrs/uuid v4.4.0+incompatible h1:3qXRTX8/NbyulANqlc0lchS1gqAVxRgsuW1YrTJupqA= github.com/gofrs/uuid v4.4.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= -github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4= -github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= @@ -347,7 +341,6 @@ github.com/lestrrat-go/jwx/v2 v2.0.21 h1:jAPKupy4uHgrHFEdjVjNkUgoBKtVDgrQPB/h55F github.com/lestrrat-go/jwx/v2 v2.0.21/go.mod h1:09mLW8zto6bWL9GbwnqAli+ArLf+5M33QLQPDggkUWM= 
github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU= github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I= -github.com/lestrrat-go/pdebug v0.0.0-20200204225717-4d6bd78da58d/go.mod h1:B06CSso/AWxiPejj+fheUINGeBKeeEZNt8w+EoU7+L8= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= @@ -415,10 +408,8 @@ github.com/opencontainers/runc v1.1.12/go.mod h1:S+lQwSfncpBha7XTy/5lBwWgm5+y5Ma github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= github.com/ory/dockertest/v3 v3.10.0 h1:4K3z2VMe8Woe++invjaTB7VRyQXQy5UY+loujO4aNE4= github.com/ory/dockertest/v3 v3.10.0/go.mod h1:nr57ZbRWMqfsdGdFNLHz5jjNdDb7VVFnzAeW1n5N1Lg= -github.com/panjf2000/ants/v2 v2.4.3/go.mod h1:f6F0NZVFsGCp5A7QW/Zj/m92atWwOkY0OIhFxRNFr4A= github.com/pascaldekloe/goe v0.1.0 h1:cBOtyMzM9HTpWjXfbbunk26uA6nG3a8n06Wieeh0MwY= github.com/pascaldekloe/goe v0.1.0/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= -github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/pborman/getopt v0.0.0-20170112200414-7148bc3a4c30/go.mod h1:85jBQOZwpVEaDAr341tbn15RS4fCAsIst0qp7i8ex1o= github.com/pelletier/go-toml v1.9.5 h1:4yBQzkHv+7BHq2PQUZF3Mx0IYxG7LsP222s7Agd3ve8= github.com/pelletier/go-toml v1.9.5/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= @@ -590,8 +581,8 @@ go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/ go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= -go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= -go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= +go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= @@ -606,8 +597,6 @@ golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= @@ -746,7 +735,6 @@ golang.org/x/time v0.5.0 
h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
 golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
 golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
-golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
 golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
@@ -818,11 +806,9 @@
 gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc=
 gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw=
 gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
 gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
-gopkg.in/square/go-jose.v2 v2.5.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
 gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
 gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
 gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
-gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
 gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
 gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
 gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
diff --git a/internal/server/coap/coap.go b/internal/server/coap/coap.go
index ddf1e6224c..a446eef501 100644
--- a/internal/server/coap/coap.go
+++ b/internal/server/coap/coap.go
@@ -6,9 +6,11 @@ package coap
 import (
 	"context"
 	"crypto/tls"
+	"crypto/x509"
 	"fmt"
 	"log/slog"
+	"os"
 	"time"
 
 	"github.com/absmach/magistrala/internal/server"
 	gocoap "github.com/plgd-dev/go-coap/v2"
@@ -51,13 +53,30 @@ func (s *Server) Start() error {
 			if err != nil {
 				return fmt.Errorf("failed to load auth certificates: %w", err)
 			}
-			tlsConfig := &tls.Config{
-				Certificates: []tls.Certificate{certificate},
+			dtlsConfig := &piondtls.Config{
+				Certificates:         []tls.Certificate{certificate},
+				ExtendedMasterSecret: piondtls.RequireExtendedMasterSecret,
+				ClientAuth:           piondtls.RequireAndVerifyClientCert,
+				ConnectContextMaker: func() (context.Context, func()) {
+					return context.WithTimeout(s.Ctx, 30*time.Second)
+				},
+			}
+			clientCA, err := loadCertFile(s.Config.ClientCAFile)
+			if err != nil {
+				return fmt.Errorf("failed to load client ca file: %w", err)
+			}
+			if len(clientCA) > 0 {
+				if dtlsConfig.ClientCAs == nil {
+					dtlsConfig.ClientCAs = x509.NewCertPool()
+				}
+				if !dtlsConfig.ClientCAs.AppendCertsFromPEM(clientCA) {
+					return fmt.Errorf("failed to append client ca to dtls config")
+				}
 			}
 			go func() {
-				errCh <- gocoap.ListenAndServeTCPTLS("udp", s.Address, tlsConfig, s.handler)
+				errCh <- gocoap.ListenAndServeDTLS("udp", s.Address, dtlsConfig, s.handler)
 			}()
 		default:
 			s.Logger.Info(fmt.Sprintf("%s service %s server listening at %s without TLS", s.Name, s.Protocol, s.Address))
 			go func() {
@@ -84,3 +103,10 @@ func (s *Server) Stop() error {
 	s.Logger.Info(fmt.Sprintf("%s service shutdown of http at %s", s.Name, s.Address))
 	return nil
 }
+
+func loadCertFile(certFile string) ([]byte, error) {
+	if certFile != "" {
+		return os.ReadFile(certFile)
+	}
+	return []byte{}, nil
+}
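For anyone reviewing the DTLS wiring above: a minimal client sketch that should be able to complete a handshake against the new listener, using only the pion/dtls v2 API surface vendored later in this diff (`Dial`, `Config`). The address, port, and cert paths are made up:

```go
package main

import (
	"crypto/tls"
	"log"
	"net"

	piondtls "github.com/pion/dtls/v2"
)

func main() {
	// The server requires and verifies a client certificate (see ClientAuth above),
	// so the client must present one; the paths are hypothetical.
	cert, err := tls.LoadX509KeyPair("client.crt", "client.key")
	if err != nil {
		log.Fatal(err)
	}
	conn, err := piondtls.Dial("udp",
		&net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 5684}, // assumed adapter DTLS port
		&piondtls.Config{
			Certificates:         []tls.Certificate{cert},
			ExtendedMasterSecret: piondtls.RequireExtendedMasterSecret,
			InsecureSkipVerify:   true, // testing only; see the vendored Config docs below
		})
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()
}
```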
diff --git a/vendor/github.com/opencontainers/runc/libcontainer/user/user.go b/vendor/github.com/opencontainers/runc/libcontainer/user/user.go
new file mode 100644
index 0000000000..a1e216683d
--- /dev/null
+++ b/vendor/github.com/opencontainers/runc/libcontainer/user/user.go
@@ -0,0 +1,605 @@
+package user
+
+import (
+	"bufio"
+	"bytes"
+	"errors"
+	"fmt"
+	"io"
+	"os"
+	"strconv"
+	"strings"
+)
+
+const (
+	minID = 0
+	maxID = 1<<31 - 1 // for 32-bit systems compatibility
+)
+
+var (
+	// ErrNoPasswdEntries is returned if no matching entries were found in the passwd file.
+	ErrNoPasswdEntries = errors.New("no matching entries in passwd file")
+	// ErrNoGroupEntries is returned if no matching entries were found in the group file.
+	ErrNoGroupEntries = errors.New("no matching entries in group file")
+	// ErrRange is returned if a UID or GID is outside of the valid range.
+	ErrRange = fmt.Errorf("uids and gids must be in range %d-%d", minID, maxID)
+)
+
+type User struct {
+	Name  string
+	Pass  string
+	Uid   int
+	Gid   int
+	Gecos string
+	Home  string
+	Shell string
+}
+
+type Group struct {
+	Name string
+	Pass string
+	Gid  int
+	List []string
+}
+
+// SubID represents an entry in /etc/sub{u,g}id
+type SubID struct {
+	Name  string
+	SubID int64
+	Count int64
+}
+
+// IDMap represents an entry in /proc/PID/{u,g}id_map
+type IDMap struct {
+	ID       int64
+	ParentID int64
+	Count    int64
+}
+
+func parseLine(line []byte, v ...interface{}) {
+	parseParts(bytes.Split(line, []byte(":")), v...)
+}
+
+func parseParts(parts [][]byte, v ...interface{}) {
+	if len(parts) == 0 {
+		return
+	}
+
+	for i, p := range parts {
+		// Ignore cases where we don't have enough fields to populate the arguments.
+		// Some configuration files like to misbehave.
+		if len(v) <= i {
+			break
+		}
+
+		// Use the type of the argument to figure out how to parse it, scanf() style.
+		// This is legit.
+		switch e := v[i].(type) {
+		case *string:
+			*e = string(p)
+		case *int:
+			// "numbers", with conversion errors ignored because of some misbehaving configuration files.
+			*e, _ = strconv.Atoi(string(p))
+		case *int64:
+			*e, _ = strconv.ParseInt(string(p), 10, 64)
+		case *[]string:
+			// Comma-separated lists.
+			if len(p) != 0 {
+				*e = strings.Split(string(p), ",")
+			} else {
+				*e = []string{}
+			}
+		default:
+			// Someone goof'd when writing code using this function. Scream so they can hear us.
+			panic(fmt.Sprintf("parseLine only accepts {*string, *int, *int64, *[]string} as arguments!
%#v is not a pointer!", e)) + } + } +} + +func ParsePasswdFile(path string) ([]User, error) { + passwd, err := os.Open(path) + if err != nil { + return nil, err + } + defer passwd.Close() + return ParsePasswd(passwd) +} + +func ParsePasswd(passwd io.Reader) ([]User, error) { + return ParsePasswdFilter(passwd, nil) +} + +func ParsePasswdFileFilter(path string, filter func(User) bool) ([]User, error) { + passwd, err := os.Open(path) + if err != nil { + return nil, err + } + defer passwd.Close() + return ParsePasswdFilter(passwd, filter) +} + +func ParsePasswdFilter(r io.Reader, filter func(User) bool) ([]User, error) { + if r == nil { + return nil, errors.New("nil source for passwd-formatted data") + } + + var ( + s = bufio.NewScanner(r) + out = []User{} + ) + + for s.Scan() { + line := bytes.TrimSpace(s.Bytes()) + if len(line) == 0 { + continue + } + + // see: man 5 passwd + // name:password:UID:GID:GECOS:directory:shell + // Name:Pass:Uid:Gid:Gecos:Home:Shell + // root:x:0:0:root:/root:/bin/bash + // adm:x:3:4:adm:/var/adm:/bin/false + p := User{} + parseLine(line, &p.Name, &p.Pass, &p.Uid, &p.Gid, &p.Gecos, &p.Home, &p.Shell) + + if filter == nil || filter(p) { + out = append(out, p) + } + } + if err := s.Err(); err != nil { + return nil, err + } + + return out, nil +} + +func ParseGroupFile(path string) ([]Group, error) { + group, err := os.Open(path) + if err != nil { + return nil, err + } + + defer group.Close() + return ParseGroup(group) +} + +func ParseGroup(group io.Reader) ([]Group, error) { + return ParseGroupFilter(group, nil) +} + +func ParseGroupFileFilter(path string, filter func(Group) bool) ([]Group, error) { + group, err := os.Open(path) + if err != nil { + return nil, err + } + defer group.Close() + return ParseGroupFilter(group, filter) +} + +func ParseGroupFilter(r io.Reader, filter func(Group) bool) ([]Group, error) { + if r == nil { + return nil, errors.New("nil source for group-formatted data") + } + rd := bufio.NewReader(r) + out := []Group{} + + // Read the file line-by-line. + for { + var ( + isPrefix bool + wholeLine []byte + err error + ) + + // Read the next line. We do so in chunks (as much as reader's + // buffer is able to keep), check if we read enough columns + // already on each step and store final result in wholeLine. + for { + var line []byte + line, isPrefix, err = rd.ReadLine() + + if err != nil { + // We should return no error if EOF is reached + // without a match. + if err == io.EOF { //nolint:errorlint // comparison with io.EOF is legit, https://github.com/polyfloyd/go-errorlint/pull/12 + err = nil + } + return out, err + } + + // Simple common case: line is short enough to fit in a + // single reader's buffer. + if !isPrefix && len(wholeLine) == 0 { + wholeLine = line + break + } + + wholeLine = append(wholeLine, line...) + + // Check if we read the whole line already. + if !isPrefix { + break + } + } + + // There's no spec for /etc/passwd or /etc/group, but we try to follow + // the same rules as the glibc parser, which allows comments and blank + // space at the beginning of a line. 
+		wholeLine = bytes.TrimSpace(wholeLine)
+		if len(wholeLine) == 0 || wholeLine[0] == '#' {
+			continue
+		}
+
+		// see: man 5 group
+		// group_name:password:GID:user_list
+		// Name:Pass:Gid:List
+		// root:x:0:root
+		// adm:x:4:root,adm,daemon
+		p := Group{}
+		parseLine(wholeLine, &p.Name, &p.Pass, &p.Gid, &p.List)
+
+		if filter == nil || filter(p) {
+			out = append(out, p)
+		}
+	}
+}
+
+type ExecUser struct {
+	Uid   int
+	Gid   int
+	Sgids []int
+	Home  string
+}
+
+// GetExecUserPath is a wrapper for GetExecUser. It reads data from each of the
+// given file paths and uses that data as the arguments to GetExecUser. If the
+// files cannot be opened for any reason, the error is ignored and a nil
+// io.Reader is passed instead.
+func GetExecUserPath(userSpec string, defaults *ExecUser, passwdPath, groupPath string) (*ExecUser, error) {
+	var passwd, group io.Reader
+
+	if passwdFile, err := os.Open(passwdPath); err == nil {
+		passwd = passwdFile
+		defer passwdFile.Close()
+	}
+
+	if groupFile, err := os.Open(groupPath); err == nil {
+		group = groupFile
+		defer groupFile.Close()
+	}
+
+	return GetExecUser(userSpec, defaults, passwd, group)
+}
+
+// GetExecUser parses a user specification string (using the passwd and group
+// readers as sources for /etc/passwd and /etc/group data, respectively). In
+// the case of blank fields or missing data from the sources, the values in
+// defaults are used.
+//
+// GetExecUser will return an error if a user or group literal could not be
+// found in any entry in passwd and group respectively.
+//
+// Examples of valid user specifications are:
+// - ""
+// - "user"
+// - "uid"
+// - "user:group"
+// - "uid:gid"
+// - "user:gid"
+// - "uid:group"
+//
+// It should be noted that if you specify a numeric user or group id, they will
+// not be evaluated as usernames (only the metadata will be filled). So attempting
+// to parse a user with user.Name = "1337" will produce the user with a UID of
+// 1337.
+func GetExecUser(userSpec string, defaults *ExecUser, passwd, group io.Reader) (*ExecUser, error) {
+	if defaults == nil {
+		defaults = new(ExecUser)
+	}
+
+	// Copy over defaults.
+	user := &ExecUser{
+		Uid:   defaults.Uid,
+		Gid:   defaults.Gid,
+		Sgids: defaults.Sgids,
+		Home:  defaults.Home,
+	}
+
+	// Sgids slice *cannot* be nil.
+	if user.Sgids == nil {
+		user.Sgids = []int{}
+	}
+
+	// Allow for userArg to have either "user" syntax, or optionally "user:group" syntax
+	var userArg, groupArg string
+	parseLine([]byte(userSpec), &userArg, &groupArg)
+
+	// Convert userArg and groupArg to be numeric, so we don't have to execute
+	// Atoi *twice* for each iteration over lines.
+	uidArg, uidErr := strconv.Atoi(userArg)
+	gidArg, gidErr := strconv.Atoi(groupArg)
+
+	// Find the matching user.
+	users, err := ParsePasswdFilter(passwd, func(u User) bool {
+		if userArg == "" {
+			// Default to current state of the user.
+			return u.Uid == user.Uid
+		}
+
+		if uidErr == nil {
+			// If the userArg is numeric, always treat it as a UID.
+			return uidArg == u.Uid
+		}
+
+		return u.Name == userArg
+	})
+
+	// If we can't find the user, we have to bail.
+	if err != nil && passwd != nil {
+		if userArg == "" {
+			userArg = strconv.Itoa(user.Uid)
+		}
+		return nil, fmt.Errorf("unable to find user %s: %w", userArg, err)
+	}
+
+	var matchedUserName string
+	if len(users) > 0 {
+		// First match wins, even if there's more than one matching entry.
+ matchedUserName = users[0].Name + user.Uid = users[0].Uid + user.Gid = users[0].Gid + user.Home = users[0].Home + } else if userArg != "" { + // If we can't find a user with the given username, the only other valid + // option is if it's a numeric username with no associated entry in passwd. + + if uidErr != nil { + // Not numeric. + return nil, fmt.Errorf("unable to find user %s: %w", userArg, ErrNoPasswdEntries) + } + user.Uid = uidArg + + // Must be inside valid uid range. + if user.Uid < minID || user.Uid > maxID { + return nil, ErrRange + } + + // Okay, so it's numeric. We can just roll with this. + } + + // On to the groups. If we matched a username, we need to do this because of + // the supplementary group IDs. + if groupArg != "" || matchedUserName != "" { + groups, err := ParseGroupFilter(group, func(g Group) bool { + // If the group argument isn't explicit, we'll just search for it. + if groupArg == "" { + // Check if user is a member of this group. + for _, u := range g.List { + if u == matchedUserName { + return true + } + } + return false + } + + if gidErr == nil { + // If the groupArg is numeric, always treat it as a GID. + return gidArg == g.Gid + } + + return g.Name == groupArg + }) + if err != nil && group != nil { + return nil, fmt.Errorf("unable to find groups for spec %v: %w", matchedUserName, err) + } + + // Only start modifying user.Gid if it is in explicit form. + if groupArg != "" { + if len(groups) > 0 { + // First match wins, even if there's more than one matching entry. + user.Gid = groups[0].Gid + } else { + // If we can't find a group with the given name, the only other valid + // option is if it's a numeric group name with no associated entry in group. + + if gidErr != nil { + // Not numeric. + return nil, fmt.Errorf("unable to find group %s: %w", groupArg, ErrNoGroupEntries) + } + user.Gid = gidArg + + // Must be inside valid gid range. + if user.Gid < minID || user.Gid > maxID { + return nil, ErrRange + } + + // Okay, so it's numeric. We can just roll with this. + } + } else if len(groups) > 0 { + // Supplementary group ids only make sense if in the implicit form. + user.Sgids = make([]int, len(groups)) + for i, group := range groups { + user.Sgids[i] = group.Gid + } + } + } + + return user, nil +} + +// GetAdditionalGroups looks up a list of groups by name or group id +// against the given /etc/group formatted data. If a group name cannot +// be found, an error will be returned. If a group id cannot be found, +// or the given group data is nil, the id will be returned as-is +// provided it is in the legal range. +func GetAdditionalGroups(additionalGroups []string, group io.Reader) ([]int, error) { + groups := []Group{} + if group != nil { + var err error + groups, err = ParseGroupFilter(group, func(g Group) bool { + for _, ag := range additionalGroups { + if g.Name == ag || strconv.Itoa(g.Gid) == ag { + return true + } + } + return false + }) + if err != nil { + return nil, fmt.Errorf("Unable to find additional groups %v: %w", additionalGroups, err) + } + } + + gidMap := make(map[int]struct{}) + for _, ag := range additionalGroups { + var found bool + for _, g := range groups { + // if we found a matched group either by name or gid, take the + // first matched as correct + if g.Name == ag || strconv.Itoa(g.Gid) == ag { + if _, ok := gidMap[g.Gid]; !ok { + gidMap[g.Gid] = struct{}{} + found = true + break + } + } + } + // we asked for a group but didn't find it. 
let's check to see + // if we wanted a numeric group + if !found { + gid, err := strconv.ParseInt(ag, 10, 64) + if err != nil { + // Not a numeric ID either. + return nil, fmt.Errorf("Unable to find group %s: %w", ag, ErrNoGroupEntries) + } + // Ensure gid is inside gid range. + if gid < minID || gid > maxID { + return nil, ErrRange + } + gidMap[int(gid)] = struct{}{} + } + } + gids := []int{} + for gid := range gidMap { + gids = append(gids, gid) + } + return gids, nil +} + +// GetAdditionalGroupsPath is a wrapper around GetAdditionalGroups +// that opens the groupPath given and gives it as an argument to +// GetAdditionalGroups. +func GetAdditionalGroupsPath(additionalGroups []string, groupPath string) ([]int, error) { + var group io.Reader + + if groupFile, err := os.Open(groupPath); err == nil { + group = groupFile + defer groupFile.Close() + } + return GetAdditionalGroups(additionalGroups, group) +} + +func ParseSubIDFile(path string) ([]SubID, error) { + subid, err := os.Open(path) + if err != nil { + return nil, err + } + defer subid.Close() + return ParseSubID(subid) +} + +func ParseSubID(subid io.Reader) ([]SubID, error) { + return ParseSubIDFilter(subid, nil) +} + +func ParseSubIDFileFilter(path string, filter func(SubID) bool) ([]SubID, error) { + subid, err := os.Open(path) + if err != nil { + return nil, err + } + defer subid.Close() + return ParseSubIDFilter(subid, filter) +} + +func ParseSubIDFilter(r io.Reader, filter func(SubID) bool) ([]SubID, error) { + if r == nil { + return nil, errors.New("nil source for subid-formatted data") + } + + var ( + s = bufio.NewScanner(r) + out = []SubID{} + ) + + for s.Scan() { + line := bytes.TrimSpace(s.Bytes()) + if len(line) == 0 { + continue + } + + // see: man 5 subuid + p := SubID{} + parseLine(line, &p.Name, &p.SubID, &p.Count) + + if filter == nil || filter(p) { + out = append(out, p) + } + } + if err := s.Err(); err != nil { + return nil, err + } + + return out, nil +} + +func ParseIDMapFile(path string) ([]IDMap, error) { + r, err := os.Open(path) + if err != nil { + return nil, err + } + defer r.Close() + return ParseIDMap(r) +} + +func ParseIDMap(r io.Reader) ([]IDMap, error) { + return ParseIDMapFilter(r, nil) +} + +func ParseIDMapFileFilter(path string, filter func(IDMap) bool) ([]IDMap, error) { + r, err := os.Open(path) + if err != nil { + return nil, err + } + defer r.Close() + return ParseIDMapFilter(r, filter) +} + +func ParseIDMapFilter(r io.Reader, filter func(IDMap) bool) ([]IDMap, error) { + if r == nil { + return nil, errors.New("nil source for idmap-formatted data") + } + + var ( + s = bufio.NewScanner(r) + out = []IDMap{} + ) + + for s.Scan() { + line := bytes.TrimSpace(s.Bytes()) + if len(line) == 0 { + continue + } + + // see: man 7 user_namespaces + p := IDMap{} + parseParts(bytes.Fields(line), &p.ID, &p.ParentID, &p.Count) + + if filter == nil || filter(p) { + out = append(out, p) + } + } + if err := s.Err(); err != nil { + return nil, err + } + + return out, nil +} diff --git a/vendor/github.com/pion/dtls/v2/AUTHORS.txt b/vendor/github.com/pion/dtls/v2/AUTHORS.txt new file mode 100644 index 0000000000..fbaf97711d --- /dev/null +++ b/vendor/github.com/pion/dtls/v2/AUTHORS.txt @@ -0,0 +1,60 @@ +# Thank you to everyone that made Pion possible. If you are interested in contributing +# we would love to have you https://github.com/pion/webrtc/wiki/Contributing +# +# This file is auto generated, using git to list all individuals contributors. 
+# see https://github.com/pion/.goassets/blob/master/scripts/generate-authors.sh for the scripting +Aleksandr Razumov +alvarowolfx +Arlo Breault +Atsushi Watanabe +backkem +bjdgyc +boks1971 +Bragadeesh +Carson Hoffman +Cecylia Bocovich +Chris Hiszpanski +cnderrauber +Daniel Mangum +Daniele Sluijters +folbrich +Hayden James +Hugo Arregui +Hugo Arregui +igolaizola <11333576+igolaizola@users.noreply.github.com> +Jeffrey Stoke +Jeroen de Bruijn +Jeroen de Bruijn +Jim Wert +jinleileiking +Jozef Kralik +Julien Salleyron +Juliusz Chroboczek +Kegan Dougal +Kevin Wang +Lander Noterman +Len +Lukas Lihotzki +ManuelBk <26275612+ManuelBk@users.noreply.github.com> +Michael Zabka +Michiel De Backker +Rachel Chen +Robert Eperjesi +Ryan Gordon +Sam Lancia +Sean +Sean DuBois +Sean DuBois +Sean DuBois +Sean DuBois +Shelikhoo +Stefan Tatschner +Steffen Vogel +Vadim +Vadim Filimonov +wmiao +ZHENK +吕海涛 + +# List of contributors not appearing in Git history + diff --git a/vendor/github.com/pion/dtls/v2/cipher_suite.go b/vendor/github.com/pion/dtls/v2/cipher_suite.go new file mode 100644 index 0000000000..6f7015c026 --- /dev/null +++ b/vendor/github.com/pion/dtls/v2/cipher_suite.go @@ -0,0 +1,276 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "crypto/ecdsa" + "crypto/ed25519" + "crypto/rsa" + "crypto/tls" + "fmt" + "hash" + + "github.com/pion/dtls/v2/internal/ciphersuite" + "github.com/pion/dtls/v2/pkg/crypto/clientcertificate" + "github.com/pion/dtls/v2/pkg/protocol/recordlayer" +) + +// CipherSuiteID is an ID for our supported CipherSuites +type CipherSuiteID = ciphersuite.ID + +// Supported Cipher Suites +const ( + // AES-128-CCM + TLS_ECDHE_ECDSA_WITH_AES_128_CCM CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_CCM //nolint:revive,stylecheck + TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 //nolint:revive,stylecheck + + // AES-128-GCM-SHA256 + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 //nolint:revive,stylecheck + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 //nolint:revive,stylecheck + + TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 //nolint:revive,stylecheck + TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 CipherSuiteID = ciphersuite.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 //nolint:revive,stylecheck + + // AES-256-CBC-SHA + TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA //nolint:revive,stylecheck + TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA CipherSuiteID = ciphersuite.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA //nolint:revive,stylecheck + + TLS_PSK_WITH_AES_128_CCM CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CCM //nolint:revive,stylecheck + TLS_PSK_WITH_AES_128_CCM_8 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CCM_8 //nolint:revive,stylecheck + TLS_PSK_WITH_AES_256_CCM_8 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_256_CCM_8 //nolint:revive,stylecheck + TLS_PSK_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_GCM_SHA256 //nolint:revive,stylecheck + TLS_PSK_WITH_AES_128_CBC_SHA256 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CBC_SHA256 //nolint:revive,stylecheck + + TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 CipherSuiteID = ciphersuite.TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 //nolint:revive,stylecheck +) + +// CipherSuiteAuthenticationType 
controls what authentication method is used during the handshake for a CipherSuite
+type CipherSuiteAuthenticationType = ciphersuite.AuthenticationType
+
+// AuthenticationType Enums
+const (
+	CipherSuiteAuthenticationTypeCertificate  CipherSuiteAuthenticationType = ciphersuite.AuthenticationTypeCertificate
+	CipherSuiteAuthenticationTypePreSharedKey CipherSuiteAuthenticationType = ciphersuite.AuthenticationTypePreSharedKey
+	CipherSuiteAuthenticationTypeAnonymous    CipherSuiteAuthenticationType = ciphersuite.AuthenticationTypeAnonymous
+)
+
+// CipherSuiteKeyExchangeAlgorithm controls what exchange algorithm is used during the handshake for a CipherSuite
+type CipherSuiteKeyExchangeAlgorithm = ciphersuite.KeyExchangeAlgorithm
+
+// CipherSuiteKeyExchangeAlgorithm Bitmask
+const (
+	CipherSuiteKeyExchangeAlgorithmNone  CipherSuiteKeyExchangeAlgorithm = ciphersuite.KeyExchangeAlgorithmNone
+	CipherSuiteKeyExchangeAlgorithmPsk   CipherSuiteKeyExchangeAlgorithm = ciphersuite.KeyExchangeAlgorithmPsk
+	CipherSuiteKeyExchangeAlgorithmEcdhe CipherSuiteKeyExchangeAlgorithm = ciphersuite.KeyExchangeAlgorithmEcdhe
+)
+
+var _ = allCipherSuites() // Necessary until this function isn't only used by Go 1.14
+
+// CipherSuite is an interface that all DTLS CipherSuites must satisfy
+type CipherSuite interface {
+	// String of CipherSuite, only used for logging
+	String() string
+
+	// ID of CipherSuite.
+	ID() CipherSuiteID
+
+	// What type of Certificate does this CipherSuite use
+	CertificateType() clientcertificate.Type
+
+	// What Hash function is used during verification
+	HashFunc() func() hash.Hash
+
+	// AuthenticationType controls what authentication method is used during the handshake
+	AuthenticationType() CipherSuiteAuthenticationType
+
+	// KeyExchangeAlgorithm controls what exchange algorithm is used during the handshake
+	KeyExchangeAlgorithm() CipherSuiteKeyExchangeAlgorithm
+
+	// ECC (Elliptic Curve Cryptography) determines whether ECC extensions will be sent during the handshake.
+	// https://datatracker.ietf.org/doc/html/rfc4492#page-10
+	ECC() bool
+
+	// Called when keying material has been generated, should initialize the internal cipher
+	Init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error
+	IsInitialized() bool
+	Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error)
+	Decrypt(h recordlayer.Header, in []byte) ([]byte, error)
+}
+
+// CipherSuiteName provides the same functionality as tls.CipherSuiteName
+// that appeared first in Go 1.14.
+//
+// Our implementation differs slightly in that it takes in a CipherSuiteID,
+// like the rest of our library, instead of a uint16 like crypto/tls.
+func CipherSuiteName(id CipherSuiteID) string {
+	suite := cipherSuiteForID(id, nil)
+	if suite != nil {
+		return suite.String()
+	}
+	return fmt.Sprintf("0x%04X", uint16(id))
+}
+
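A quick sketch of the fallback behavior of CipherSuiteName above (the printed name assumes the vendored ciphersuite String() implementations):

```go
package main

import (
	"fmt"

	"github.com/pion/dtls/v2"
)

func main() {
	// A registered suite prints its IANA-style name...
	fmt.Println(dtls.CipherSuiteName(dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256))
	// ...while an unknown ID falls back to "0x%04X" formatting: 0x1337.
	fmt.Println(dtls.CipherSuiteName(0x1337))
}
```

+// Taken from https://www.iana.org/assignments/tls-parameters/tls-parameters.xml
+// A cipherSuite is a specific combination of key agreement, cipher and MAC
+// function.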
+func cipherSuiteForID(id CipherSuiteID, customCiphers func() []CipherSuite) CipherSuite { + switch id { //nolint:exhaustive + case TLS_ECDHE_ECDSA_WITH_AES_128_CCM: + return ciphersuite.NewTLSEcdheEcdsaWithAes128Ccm() + case TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8: + return ciphersuite.NewTLSEcdheEcdsaWithAes128Ccm8() + case TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: + return &ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{} + case TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: + return &ciphersuite.TLSEcdheRsaWithAes128GcmSha256{} + case TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA: + return &ciphersuite.TLSEcdheEcdsaWithAes256CbcSha{} + case TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA: + return &ciphersuite.TLSEcdheRsaWithAes256CbcSha{} + case TLS_PSK_WITH_AES_128_CCM: + return ciphersuite.NewTLSPskWithAes128Ccm() + case TLS_PSK_WITH_AES_128_CCM_8: + return ciphersuite.NewTLSPskWithAes128Ccm8() + case TLS_PSK_WITH_AES_256_CCM_8: + return ciphersuite.NewTLSPskWithAes256Ccm8() + case TLS_PSK_WITH_AES_128_GCM_SHA256: + return &ciphersuite.TLSPskWithAes128GcmSha256{} + case TLS_PSK_WITH_AES_128_CBC_SHA256: + return &ciphersuite.TLSPskWithAes128CbcSha256{} + case TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: + return &ciphersuite.TLSEcdheEcdsaWithAes256GcmSha384{} + case TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384: + return &ciphersuite.TLSEcdheRsaWithAes256GcmSha384{} + case TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256: + return ciphersuite.NewTLSEcdhePskWithAes128CbcSha256() + } + + if customCiphers != nil { + for _, c := range customCiphers() { + if c.ID() == id { + return c + } + } + } + + return nil +} + +// CipherSuites we support in order of preference +func defaultCipherSuites() []CipherSuite { + return []CipherSuite{ + &ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{}, + &ciphersuite.TLSEcdheRsaWithAes128GcmSha256{}, + &ciphersuite.TLSEcdheEcdsaWithAes256CbcSha{}, + &ciphersuite.TLSEcdheRsaWithAes256CbcSha{}, + &ciphersuite.TLSEcdheEcdsaWithAes256GcmSha384{}, + &ciphersuite.TLSEcdheRsaWithAes256GcmSha384{}, + } +} + +func allCipherSuites() []CipherSuite { + return []CipherSuite{ + ciphersuite.NewTLSEcdheEcdsaWithAes128Ccm(), + ciphersuite.NewTLSEcdheEcdsaWithAes128Ccm8(), + &ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{}, + &ciphersuite.TLSEcdheRsaWithAes128GcmSha256{}, + &ciphersuite.TLSEcdheEcdsaWithAes256CbcSha{}, + &ciphersuite.TLSEcdheRsaWithAes256CbcSha{}, + ciphersuite.NewTLSPskWithAes128Ccm(), + ciphersuite.NewTLSPskWithAes128Ccm8(), + ciphersuite.NewTLSPskWithAes256Ccm8(), + &ciphersuite.TLSPskWithAes128GcmSha256{}, + &ciphersuite.TLSEcdheEcdsaWithAes256GcmSha384{}, + &ciphersuite.TLSEcdheRsaWithAes256GcmSha384{}, + } +} + +func cipherSuiteIDs(cipherSuites []CipherSuite) []uint16 { + rtrn := []uint16{} + for _, c := range cipherSuites { + rtrn = append(rtrn, uint16(c.ID())) + } + return rtrn +} + +func parseCipherSuites(userSelectedSuites []CipherSuiteID, customCipherSuites func() []CipherSuite, includeCertificateSuites, includePSKSuites bool) ([]CipherSuite, error) { + cipherSuitesForIDs := func(ids []CipherSuiteID) ([]CipherSuite, error) { + cipherSuites := []CipherSuite{} + for _, id := range ids { + c := cipherSuiteForID(id, nil) + if c == nil { + return nil, &invalidCipherSuiteError{id} + } + cipherSuites = append(cipherSuites, c) + } + return cipherSuites, nil + } + + var ( + cipherSuites []CipherSuite + err error + i int + ) + if userSelectedSuites != nil { + cipherSuites, err = cipherSuitesForIDs(userSelectedSuites) + if err != nil { + return nil, err + } + } else { + cipherSuites = defaultCipherSuites() + } + + // Put 
CustomCipherSuites before ID selected suites
+	if customCipherSuites != nil {
+		cipherSuites = append(customCipherSuites(), cipherSuites...)
+	}
+
+	var foundCertificateSuite, foundPSKSuite, foundAnonymousSuite bool
+	for _, c := range cipherSuites {
+		switch {
+		case includeCertificateSuites && c.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate:
+			foundCertificateSuite = true
+		case includePSKSuites && c.AuthenticationType() == CipherSuiteAuthenticationTypePreSharedKey:
+			foundPSKSuite = true
+		case c.AuthenticationType() == CipherSuiteAuthenticationTypeAnonymous:
+			foundAnonymousSuite = true
+		default:
+			continue
+		}
+		cipherSuites[i] = c
+		i++
+	}
+
+	switch {
+	case includeCertificateSuites && !foundCertificateSuite && !foundAnonymousSuite:
+		return nil, errNoAvailableCertificateCipherSuite
+	case includePSKSuites && !foundPSKSuite:
+		return nil, errNoAvailablePSKCipherSuite
+	case i == 0:
+		return nil, errNoAvailableCipherSuites
+	}
+
+	return cipherSuites[:i], nil
+}
+
+func filterCipherSuitesForCertificate(cert *tls.Certificate, cipherSuites []CipherSuite) []CipherSuite {
+	if cert == nil || cert.PrivateKey == nil {
+		return cipherSuites
+	}
+	var certType clientcertificate.Type
+	switch cert.PrivateKey.(type) {
+	case ed25519.PrivateKey, *ecdsa.PrivateKey:
+		certType = clientcertificate.ECDSASign
+	case *rsa.PrivateKey:
+		certType = clientcertificate.RSASign
+	}
+
+	filtered := []CipherSuite{}
+	for _, c := range cipherSuites {
+		if c.AuthenticationType() != CipherSuiteAuthenticationTypeCertificate || certType == c.CertificateType() {
+			filtered = append(filtered, c)
+		}
+	}
+	return filtered
+}
diff --git a/vendor/github.com/pion/dtls/v2/config.go b/vendor/github.com/pion/dtls/v2/config.go
new file mode 100644
index 0000000000..604a4d5757
--- /dev/null
+++ b/vendor/github.com/pion/dtls/v2/config.go
@@ -0,0 +1,273 @@
+// SPDX-FileCopyrightText: 2023 The Pion community
+// SPDX-License-Identifier: MIT
+
+package dtls
+
+import (
+	"context"
+	"crypto/ecdsa"
+	"crypto/ed25519"
+	"crypto/rsa"
+	"crypto/tls"
+	"crypto/x509"
+	"io"
+	"time"
+
+	"github.com/pion/dtls/v2/pkg/crypto/elliptic"
+	"github.com/pion/logging"
+)
+
+const keyLogLabelTLS12 = "CLIENT_RANDOM"
+
+// Config is used to configure a DTLS client or server.
+// After a Config is passed to a DTLS function it must not be modified.
+type Config struct {
+	// Certificates contains the certificate chain to present to the other side of the connection.
+	// The server MUST set this if PSK is non-nil.
+	// The client SHOULD set this so CertificateRequests can be handled if PSK is non-nil.
+	Certificates []tls.Certificate
+
+	// CipherSuites is a list of supported cipher suites.
+	// If CipherSuites is nil, a default list is used
+	CipherSuites []CipherSuiteID
+
+	// CustomCipherSuites is a list of CipherSuites that can be
+	// provided by the user. This allows users to use ciphers that are reserved
+	// for private usage.
+	CustomCipherSuites func() []CipherSuite
+
+	// SignatureSchemes contains the signature and hash schemes that the peer requests to verify.
+	SignatureSchemes []tls.SignatureScheme
+
+	// SRTPProtectionProfiles are the supported protection profiles
+	// Clients will send this via use_srtp and assert that the server properly responds
+	// Servers will assert that clients send one of these profiles and will respond as needed
+	SRTPProtectionProfiles []SRTPProtectionProfile
+
+	// ClientAuth determines the server's policy for
+	// TLS Client Authentication. The default is NoClientCert.
+	ClientAuth ClientAuthType
+
+	// ExtendedMasterSecret determines if the "Extended Master Secret" extension
+	// should be disabled, requested, or required (default requested).
+	ExtendedMasterSecret ExtendedMasterSecretType
+
+	// FlightInterval controls how often we send outbound handshake messages
+	// defaults to time.Second
+	FlightInterval time.Duration
+
+	// PSK sets the pre-shared key used by this DTLS connection
+	// If PSK is non-nil only PSK CipherSuites will be used
+	PSK             PSKCallback
+	PSKIdentityHint []byte
+
+	// InsecureSkipVerify controls whether a client verifies the
+	// server's certificate chain and host name.
+	// If InsecureSkipVerify is true, TLS accepts any certificate
+	// presented by the server and any host name in that certificate.
+	// In this mode, TLS is susceptible to man-in-the-middle attacks.
+	// This should be used only for testing.
+	InsecureSkipVerify bool
+
+	// InsecureHashes allows the use of hashing algorithms that are known
+	// to be vulnerable.
+	InsecureHashes bool
+
+	// VerifyPeerCertificate, if not nil, is called after normal
+	// certificate verification by either a client or server. It
+	// receives the certificate provided by the peer and also a flag
+	// that tells if normal verification has succeeded. If it returns a
+	// non-nil error, the handshake is aborted and that error results.
+	//
+	// If normal verification fails then the handshake will abort before
+	// considering this callback. If normal verification is disabled by
+	// setting InsecureSkipVerify, or (for a server) when ClientAuth is
+	// RequestClientCert or RequireAnyClientCert, then this callback will
+	// be considered but the verifiedChains will always be nil.
+	VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error
+
+	// VerifyConnection, if not nil, is called after normal certificate
+	// verification/PSK and after VerifyPeerCertificate by either a TLS client
+	// or server. If it returns a non-nil error, the handshake is aborted
+	// and that error results.
+	//
+	// If normal verification fails then the handshake will abort before
+	// considering this callback. This callback will run for all connections
+	// regardless of InsecureSkipVerify or ClientAuth settings.
+	VerifyConnection func(*State) error
+
+	// RootCAs defines the set of root certificate authorities
+	// that one peer uses when verifying the other peer's certificates.
+	// If RootCAs is nil, TLS uses the host's root CA set.
+	RootCAs *x509.CertPool
+
+	// ClientCAs defines the set of root certificate authorities
+	// that servers use if required to verify a client certificate
+	// by the policy in ClientAuth.
+	ClientCAs *x509.CertPool
+
+	// ServerName is used to verify the hostname on the returned
+	// certificates unless InsecureSkipVerify is given.
+	ServerName string
+
+	LoggerFactory logging.LoggerFactory
+
+	// ConnectContextMaker is a function to make a context used in Dial(),
+	// Client(), Server(), and Accept(). If nil, the default ConnectContextMaker
+	// is used. It can be implemented as follows.
+	//
+	//	func ConnectContextMaker() (context.Context, func()) {
+	//		return context.WithTimeout(context.Background(), 30*time.Second)
+	//	}
+	ConnectContextMaker func() (context.Context, func())
+
+	// MTU is the length at which handshake messages will be fragmented to
+	// fit within the maximum transmission unit (default is 1200 bytes)
+	MTU int
+
+	// ReplayProtectionWindow is the size of the replay attack protection window.
+	// Duplication of the sequence number is checked in this window size.
+	// Packets with a sequence number older than this value, compared to the latest
+	// accepted packet, will be discarded. (default is 64)
+	ReplayProtectionWindow int
+
+	// KeyLogWriter optionally specifies a destination for TLS master secrets
+	// in NSS key log format that can be used to allow external programs
+	// such as Wireshark to decrypt TLS connections.
+	// See https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/Key_Log_Format.
+	// Use of KeyLogWriter compromises security and should only be
+	// used for debugging.
+	KeyLogWriter io.Writer
+
+	// SessionStore is the container to store session for resumption.
+	SessionStore SessionStore
+
+	// List of application protocols the peer supports, for ALPN
+	SupportedProtocols []string
+
+	// List of Elliptic Curves to use
+	//
+	// If an ECC ciphersuite is configured and EllipticCurves is empty
+	// it will default to X25519, P-256, P-384 in this specific order.
+	EllipticCurves []elliptic.Curve
+
+	// GetCertificate returns a Certificate based on the given
+	// ClientHelloInfo. It will only be called if the client supplies SNI
+	// information or if Certificates is empty.
+	//
+	// If GetCertificate is nil or returns nil, then the certificate is
+	// retrieved from NameToCertificate. If NameToCertificate is nil, the
+	// best element of Certificates will be used.
+	GetCertificate func(*ClientHelloInfo) (*tls.Certificate, error)
+
+	// GetClientCertificate, if not nil, is called when a server requests a
+	// certificate from a client. If set, the contents of Certificates will
+	// be ignored.
+	//
+	// If GetClientCertificate returns an error, the handshake will be
+	// aborted and that error will be returned. Otherwise
+	// GetClientCertificate must return a non-nil Certificate. If
+	// Certificate.Certificate is empty then no certificate will be sent to
+	// the server. If this is unacceptable to the server then it may abort
+	// the handshake.
+	GetClientCertificate func(*CertificateRequestInfo) (*tls.Certificate, error)
+
+	// InsecureSkipVerifyHello, if true and when acting as server, allows a client to
+	// skip the hello verify phase and receive a ServerHello after the initial ClientHello.
+	// This has implications for DoS attack resistance.
+	InsecureSkipVerifyHello bool
+
+	// ConnectionIDGenerator generates connection identifiers that should be
+	// sent by the remote party if it supports the DTLS Connection Identifier
+	// extension, as determined during the handshake. Generated connection
+	// identifiers must always have the same length. Returning a zero-length
+	// connection identifier indicates that the local party supports sending
+	// connection identifiers but does not require the remote party to send
+	// them. A nil ConnectionIDGenerator indicates that connection identifiers
+	// are not supported.
+	// https://datatracker.ietf.org/doc/html/rfc9146
+	ConnectionIDGenerator func() []byte
+
+	// PaddingLengthGenerator generates the number of padding bytes used to
+	// inflate ciphertext size in order to obscure content size from observers.
+	// The length of the content is passed to the generator such that both
+	// deterministic and random padding schemes can be applied while not
+	// exceeding maximum record size.
+	// If no PaddingLengthGenerator is specified, padding will not be applied.
+ // https://datatracker.ietf.org/doc/html/rfc9146#section-4 + PaddingLengthGenerator func(uint) uint +} + +func defaultConnectContextMaker() (context.Context, func()) { + return context.WithTimeout(context.Background(), 30*time.Second) +} + +func (c *Config) connectContextMaker() (context.Context, func()) { + if c.ConnectContextMaker == nil { + return defaultConnectContextMaker() + } + return c.ConnectContextMaker() +} + +func (c *Config) includeCertificateSuites() bool { + return c.PSK == nil || len(c.Certificates) > 0 || c.GetCertificate != nil || c.GetClientCertificate != nil +} + +const defaultMTU = 1200 // bytes + +var defaultCurves = []elliptic.Curve{elliptic.X25519, elliptic.P256, elliptic.P384} //nolint:gochecknoglobals + +// PSKCallback is called once we have the remote's PSKIdentityHint. +// If the remote provided none it will be nil +type PSKCallback func([]byte) ([]byte, error) + +// ClientAuthType declares the policy the server will follow for +// TLS Client Authentication. +type ClientAuthType int + +// ClientAuthType enums +const ( + NoClientCert ClientAuthType = iota + RequestClientCert + RequireAnyClientCert + VerifyClientCertIfGiven + RequireAndVerifyClientCert +) + +// ExtendedMasterSecretType declares the policy the client and server +// will follow for the Extended Master Secret extension +type ExtendedMasterSecretType int + +// ExtendedMasterSecretType enums +const ( + RequestExtendedMasterSecret ExtendedMasterSecretType = iota + RequireExtendedMasterSecret + DisableExtendedMasterSecret +) + +func validateConfig(config *Config) error { + switch { + case config == nil: + return errNoConfigProvided + case config.PSKIdentityHint != nil && config.PSK == nil: + return errIdentityNoPSK + } + + for _, cert := range config.Certificates { + if cert.Certificate == nil { + return errInvalidCertificate + } + if cert.PrivateKey != nil { + switch cert.PrivateKey.(type) { + case ed25519.PrivateKey: + case *ecdsa.PrivateKey: + case *rsa.PrivateKey: + default: + return errInvalidPrivateKey + } + } + } + + _, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil) + return err +} diff --git a/vendor/github.com/pion/dtls/v2/conn.go b/vendor/github.com/pion/dtls/v2/conn.go new file mode 100644 index 0000000000..9d1da84cb1 --- /dev/null +++ b/vendor/github.com/pion/dtls/v2/conn.go @@ -0,0 +1,1178 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/pion/dtls/v2/internal/closer" + "github.com/pion/dtls/v2/pkg/crypto/elliptic" + "github.com/pion/dtls/v2/pkg/crypto/signaturehash" + "github.com/pion/dtls/v2/pkg/protocol" + "github.com/pion/dtls/v2/pkg/protocol/alert" + "github.com/pion/dtls/v2/pkg/protocol/handshake" + "github.com/pion/dtls/v2/pkg/protocol/recordlayer" + "github.com/pion/logging" + "github.com/pion/transport/v3/deadline" + "github.com/pion/transport/v3/netctx" + "github.com/pion/transport/v3/replaydetector" +) + +const ( + initialTickerInterval = time.Second + cookieLength = 20 + sessionLength = 32 + defaultNamedCurve = elliptic.X25519 + inboundBufferSize = 8192 + // Default replay protection window is specified by RFC 6347 Section 4.1.2.6 + defaultReplayProtectionWindow = 64 +) + +func invalidKeyingLabels() map[string]bool { + return map[string]bool{ + "client finished": true, + "server finished": true, + "master secret": true, + "key 
expansion": true, + } +} + +type addrPkt struct { + rAddr net.Addr + data []byte +} + +// Conn represents a DTLS connection +type Conn struct { + lock sync.RWMutex // Internal lock (must not be public) + nextConn netctx.PacketConn // Embedded Conn, typically a udpconn we read/write from + fragmentBuffer *fragmentBuffer // out-of-order and missing fragment handling + handshakeCache *handshakeCache // caching of handshake messages for verifyData generation + decrypted chan interface{} // Decrypted Application Data or error, pull by calling `Read` + rAddr net.Addr + state State // Internal state + + maximumTransmissionUnit int + paddingLengthGenerator func(uint) uint + + handshakeCompletedSuccessfully atomic.Value + + encryptedPackets []addrPkt + + connectionClosedByUser bool + closeLock sync.Mutex + closed *closer.Closer + handshakeLoopsFinished sync.WaitGroup + + readDeadline *deadline.Deadline + writeDeadline *deadline.Deadline + + log logging.LeveledLogger + + reading chan struct{} + handshakeRecv chan chan struct{} + cancelHandshaker func() + cancelHandshakeReader func() + + fsm *handshakeFSM + + replayProtectionWindow uint +} + +func createConn(ctx context.Context, nextConn net.PacketConn, rAddr net.Addr, config *Config, isClient bool, initialState *State) (*Conn, error) { + if err := validateConfig(config); err != nil { + return nil, err + } + + if nextConn == nil { + return nil, errNilNextConn + } + + cipherSuites, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil) + if err != nil { + return nil, err + } + + signatureSchemes, err := signaturehash.ParseSignatureSchemes(config.SignatureSchemes, config.InsecureHashes) + if err != nil { + return nil, err + } + + workerInterval := initialTickerInterval + if config.FlightInterval != 0 { + workerInterval = config.FlightInterval + } + + loggerFactory := config.LoggerFactory + if loggerFactory == nil { + loggerFactory = logging.NewDefaultLoggerFactory() + } + + logger := loggerFactory.NewLogger("dtls") + + mtu := config.MTU + if mtu <= 0 { + mtu = defaultMTU + } + + replayProtectionWindow := config.ReplayProtectionWindow + if replayProtectionWindow <= 0 { + replayProtectionWindow = defaultReplayProtectionWindow + } + + paddingLengthGenerator := config.PaddingLengthGenerator + if paddingLengthGenerator == nil { + paddingLengthGenerator = func(uint) uint { return 0 } + } + + c := &Conn{ + rAddr: rAddr, + nextConn: netctx.NewPacketConn(nextConn), + fragmentBuffer: newFragmentBuffer(), + handshakeCache: newHandshakeCache(), + maximumTransmissionUnit: mtu, + paddingLengthGenerator: paddingLengthGenerator, + + decrypted: make(chan interface{}, 1), + log: logger, + + readDeadline: deadline.New(), + writeDeadline: deadline.New(), + + reading: make(chan struct{}, 1), + handshakeRecv: make(chan chan struct{}), + closed: closer.NewCloser(), + cancelHandshaker: func() {}, + + replayProtectionWindow: uint(replayProtectionWindow), + + state: State{ + isClient: isClient, + }, + } + + c.setRemoteEpoch(0) + c.setLocalEpoch(0) + + serverName := config.ServerName + // Do not allow the use of an IP address literal as an SNI value. + // See RFC 6066, Section 3. 
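createConn above falls back to documented defaults for any zero-valued Config fields: MTU 1200, a 64-packet replay window, the X25519/P-256/P-384 curve list, and a 30-second handshake timeout from defaultConnectContextMaker. A minimal sketch of overriding only the handshake timeout, with all other values left at those defaults:

package example

import (
	"context"
	"time"

	"github.com/pion/dtls/v2"
)

func configWithShortHandshakeTimeout() *dtls.Config {
	return &dtls.Config{
		// Every other field keeps the defaults applied in createConn.
		ConnectContextMaker: func() (context.Context, func()) {
			// A 5s handshake deadline instead of the 30s default.
			return context.WithTimeout(context.Background(), 5*time.Second)
		},
	}
}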
+ if net.ParseIP(serverName) != nil { + serverName = "" + } + + curves := config.EllipticCurves + if len(curves) == 0 { + curves = defaultCurves + } + + hsCfg := &handshakeConfig{ + localPSKCallback: config.PSK, + localPSKIdentityHint: config.PSKIdentityHint, + localCipherSuites: cipherSuites, + localSignatureSchemes: signatureSchemes, + extendedMasterSecret: config.ExtendedMasterSecret, + localSRTPProtectionProfiles: config.SRTPProtectionProfiles, + serverName: serverName, + supportedProtocols: config.SupportedProtocols, + clientAuth: config.ClientAuth, + localCertificates: config.Certificates, + insecureSkipVerify: config.InsecureSkipVerify, + verifyPeerCertificate: config.VerifyPeerCertificate, + verifyConnection: config.VerifyConnection, + rootCAs: config.RootCAs, + clientCAs: config.ClientCAs, + customCipherSuites: config.CustomCipherSuites, + retransmitInterval: workerInterval, + log: logger, + initialEpoch: 0, + keyLogWriter: config.KeyLogWriter, + sessionStore: config.SessionStore, + ellipticCurves: curves, + localGetCertificate: config.GetCertificate, + localGetClientCertificate: config.GetClientCertificate, + insecureSkipHelloVerify: config.InsecureSkipVerifyHello, + connectionIDGenerator: config.ConnectionIDGenerator, + } + + // rfc5246#section-7.4.3 + // In addition, the hash and signature algorithms MUST be compatible + // with the key in the server's end-entity certificate. + if !isClient { + cert, err := hsCfg.getCertificate(&ClientHelloInfo{}) + if err != nil && !errors.Is(err, errNoCertificates) { + return nil, err + } + hsCfg.localCipherSuites = filterCipherSuitesForCertificate(cert, cipherSuites) + } + + var initialFlight flightVal + var initialFSMState handshakeState + + if initialState != nil { + if c.state.isClient { + initialFlight = flight5 + } else { + initialFlight = flight6 + } + initialFSMState = handshakeFinished + + c.state = *initialState + } else { + if c.state.isClient { + initialFlight = flight1 + } else { + initialFlight = flight0 + } + initialFSMState = handshakePreparing + } + // Do handshake + if err := c.handshake(ctx, hsCfg, initialFlight, initialFSMState); err != nil { + return nil, err + } + + c.log.Trace("Handshake Completed") + + return c, nil +} + +// Dial connects to the given network address and establishes a DTLS connection on top. +// Connection handshake will timeout using ConnectContextMaker in the Config. +// If you want to specify the timeout duration, use DialWithContext() instead. +func Dial(network string, rAddr *net.UDPAddr, config *Config) (*Conn, error) { + ctx, cancel := config.connectContextMaker() + defer cancel() + + return DialWithContext(ctx, network, rAddr, config) +} + +// Client establishes a DTLS connection over an existing connection. +// Connection handshake will timeout using ConnectContextMaker in the Config. +// If you want to specify the timeout duration, use ClientWithContext() instead. +func Client(conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) { + ctx, cancel := config.connectContextMaker() + defer cancel() + + return ClientWithContext(ctx, conn, rAddr, config) +} + +// Server listens for incoming DTLS connections. +// Connection handshake will timeout using ConnectContextMaker in the Config. +// If you want to specify the timeout duration, use ServerWithContext() instead. 
+func Server(conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) { + ctx, cancel := config.connectContextMaker() + defer cancel() + + return ServerWithContext(ctx, conn, rAddr, config) +} + +// DialWithContext connects to the given network address and establishes a DTLS +// connection on top. +func DialWithContext(ctx context.Context, network string, rAddr *net.UDPAddr, config *Config) (*Conn, error) { + // net.ListenUDP is used rather than net.DialUDP as the latter prevents the + // use of net.PacketConn.WriteTo. + // https://github.com/golang/go/blob/ce5e37ec21442c6eb13a43e68ca20129102ebac0/src/net/udpsock_posix.go#L115 + pConn, err := net.ListenUDP(network, nil) + if err != nil { + return nil, err + } + + return ClientWithContext(ctx, pConn, rAddr, config) +} + +// ClientWithContext establishes a DTLS connection over an existing connection. +func ClientWithContext(ctx context.Context, conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) { + switch { + case config == nil: + return nil, errNoConfigProvided + case config.PSK != nil && config.PSKIdentityHint == nil: + return nil, errPSKAndIdentityMustBeSetForClient + } + + return createConn(ctx, conn, rAddr, config, true, nil) +} + +// ServerWithContext listens for incoming DTLS connections. +func ServerWithContext(ctx context.Context, conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) { + if config == nil { + return nil, errNoConfigProvided + } + + return createConn(ctx, conn, rAddr, config, false, nil) +} + +// Read reads data from the connection. +func (c *Conn) Read(p []byte) (n int, err error) { + if !c.isHandshakeCompletedSuccessfully() { + return 0, errHandshakeInProgress + } + + select { + case <-c.readDeadline.Done(): + return 0, errDeadlineExceeded + default: + } + + for { + select { + case <-c.readDeadline.Done(): + return 0, errDeadlineExceeded + case out, ok := <-c.decrypted: + if !ok { + return 0, io.EOF + } + switch val := out.(type) { + case ([]byte): + if len(p) < len(val) { + return 0, errBufferTooSmall + } + copy(p, val) + return len(val), nil + case (error): + return 0, val + } + } + } +} + +// Write writes len(p) bytes from p to the DTLS connection +func (c *Conn) Write(p []byte) (int, error) { + if c.isConnectionClosed() { + return 0, ErrConnClosed + } + + select { + case <-c.writeDeadline.Done(): + return 0, errDeadlineExceeded + default: + } + + if !c.isHandshakeCompletedSuccessfully() { + return 0, errHandshakeInProgress + } + + return len(p), c.writePackets(c.writeDeadline, []*packet{ + { + record: &recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Epoch: c.state.getLocalEpoch(), + Version: protocol.Version1_2, + }, + Content: &protocol.ApplicationData{ + Data: p, + }, + }, + shouldWrapCID: len(c.state.remoteConnectionID) > 0, + shouldEncrypt: true, + }, + }) +} + +// Close closes the connection. +func (c *Conn) Close() error { + err := c.close(true) //nolint:contextcheck + c.handshakeLoopsFinished.Wait() + return err +} + +// ConnectionState returns basic DTLS details about the connection. +// Note that this replaced the `Export` function of v1. 
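Taken together, Dial/Client/Server and the Read/Write methods above are the whole public surface needed for a PSK client. A usage sketch, assuming a peer listening on 127.0.0.1:4444, hypothetical key material, and that TLS_PSK_WITH_AES_128_CCM_8 is among the library's PSK suite constants:

package main

import (
	"fmt"
	"net"
	"time"

	"github.com/pion/dtls/v2"
)

func main() {
	rAddr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444}
	cfg := &dtls.Config{
		// Both peers must agree on this key; the callback receives the
		// server's PSK identity hint, if any.
		PSK: func(hint []byte) ([]byte, error) {
			return []byte{0xAB, 0xC1, 0x23}, nil
		},
		// Required on the client whenever PSK is set (see ClientWithContext).
		PSKIdentityHint: []byte("client-1"),
		// Assumed to be one of the library's PSK suite constants.
		CipherSuites: []dtls.CipherSuiteID{dtls.TLS_PSK_WITH_AES_128_CCM_8},
		// Pad records to the next multiple of 64 bytes, per the
		// PaddingLengthGenerator contract in the Config above.
		PaddingLengthGenerator: func(n uint) uint { return (64 - n%64) % 64 },
	}

	conn, err := dtls.Dial("udp", rAddr, cfg) // handshake deadline comes from ConnectContextMaker
	if err != nil {
		panic(err)
	}
	defer conn.Close()

	// Deadlines are managed by this DTLS layer, not the underlying socket.
	if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
		panic(err)
	}
	if _, err := conn.Write([]byte("hello")); err != nil {
		panic(err)
	}
	buf := make([]byte, 8192) // Read fails with errBufferTooSmall on short buffers
	n, err := conn.Read(buf)
	if err != nil {
		panic(err)
	}
	fmt.Printf("read %d bytes\n", n)
}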
+func (c *Conn) ConnectionState() State { + c.lock.RLock() + defer c.lock.RUnlock() + return *c.state.clone() +} + +// SelectedSRTPProtectionProfile returns the selected SRTPProtectionProfile +func (c *Conn) SelectedSRTPProtectionProfile() (SRTPProtectionProfile, bool) { + c.lock.RLock() + defer c.lock.RUnlock() + + if c.state.srtpProtectionProfile == 0 { + return 0, false + } + + return c.state.srtpProtectionProfile, true +} + +func (c *Conn) writePackets(ctx context.Context, pkts []*packet) error { + c.lock.Lock() + defer c.lock.Unlock() + + var rawPackets [][]byte + + for _, p := range pkts { + if h, ok := p.record.Content.(*handshake.Handshake); ok { + handshakeRaw, err := p.record.Marshal() + if err != nil { + return err + } + + c.log.Tracef("[handshake:%v] -> %s (epoch: %d, seq: %d)", + srvCliStr(c.state.isClient), h.Header.Type.String(), + p.record.Header.Epoch, h.Header.MessageSequence) + + c.handshakeCache.push(handshakeRaw[recordlayer.FixedHeaderSize:], p.record.Header.Epoch, h.Header.MessageSequence, h.Header.Type, c.state.isClient) + + rawHandshakePackets, err := c.processHandshakePacket(p, h) + if err != nil { + return err + } + rawPackets = append(rawPackets, rawHandshakePackets...) + } else { + rawPacket, err := c.processPacket(p) + if err != nil { + return err + } + rawPackets = append(rawPackets, rawPacket) + } + } + if len(rawPackets) == 0 { + return nil + } + compactedRawPackets := c.compactRawPackets(rawPackets) + + for _, compactedRawPackets := range compactedRawPackets { + if _, err := c.nextConn.WriteToContext(ctx, compactedRawPackets, c.rAddr); err != nil { + return netError(err) + } + } + + return nil +} + +func (c *Conn) compactRawPackets(rawPackets [][]byte) [][]byte { + // avoid a useless copy in the common case + if len(rawPackets) == 1 { + return rawPackets + } + + combinedRawPackets := make([][]byte, 0) + currentCombinedRawPacket := make([]byte, 0) + + for _, rawPacket := range rawPackets { + if len(currentCombinedRawPacket) > 0 && len(currentCombinedRawPacket)+len(rawPacket) >= c.maximumTransmissionUnit { + combinedRawPackets = append(combinedRawPackets, currentCombinedRawPacket) + currentCombinedRawPacket = []byte{} + } + currentCombinedRawPacket = append(currentCombinedRawPacket, rawPacket...) + } + + combinedRawPackets = append(combinedRawPackets, currentCombinedRawPacket) + + return combinedRawPackets +} + +func (c *Conn) processPacket(p *packet) ([]byte, error) { + epoch := p.record.Header.Epoch + for len(c.state.localSequenceNumber) <= int(epoch) { + c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0)) + } + seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1 + if seq > recordlayer.MaxSequenceNumber { + // RFC 6347 Section 4.1.0 + // The implementation must either abandon an association or rehandshake + // prior to allowing the sequence number to wrap. + return nil, errSequenceNumberOverflow + } + p.record.Header.SequenceNumber = seq + + var rawPacket []byte + if p.shouldWrapCID { + // Record must be marshaled to populate fields used in inner plaintext. 
+ if _, err := p.record.Marshal(); err != nil { + return nil, err + } + content, err := p.record.Content.Marshal() + if err != nil { + return nil, err + } + inner := &recordlayer.InnerPlaintext{ + Content: content, + RealType: p.record.Header.ContentType, + } + rawInner, err := inner.Marshal() //nolint:govet + if err != nil { + return nil, err + } + cidHeader := &recordlayer.Header{ + Version: p.record.Header.Version, + ContentType: protocol.ContentTypeConnectionID, + Epoch: p.record.Header.Epoch, + ContentLen: uint16(len(rawInner)), + ConnectionID: c.state.remoteConnectionID, + SequenceNumber: p.record.Header.SequenceNumber, + } + rawPacket, err = cidHeader.Marshal() + if err != nil { + return nil, err + } + p.record.Header = *cidHeader + rawPacket = append(rawPacket, rawInner...) + } else { + var err error + rawPacket, err = p.record.Marshal() + if err != nil { + return nil, err + } + } + + if p.shouldEncrypt { + var err error + rawPacket, err = c.state.cipherSuite.Encrypt(p.record, rawPacket) + if err != nil { + return nil, err + } + } + + return rawPacket, nil +} + +func (c *Conn) processHandshakePacket(p *packet, h *handshake.Handshake) ([][]byte, error) { + rawPackets := make([][]byte, 0) + + handshakeFragments, err := c.fragmentHandshake(h) + if err != nil { + return nil, err + } + epoch := p.record.Header.Epoch + for len(c.state.localSequenceNumber) <= int(epoch) { + c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0)) + } + + for _, handshakeFragment := range handshakeFragments { + seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1 + if seq > recordlayer.MaxSequenceNumber { + return nil, errSequenceNumberOverflow + } + + var rawPacket []byte + if p.shouldWrapCID { + inner := &recordlayer.InnerPlaintext{ + Content: handshakeFragment, + RealType: protocol.ContentTypeHandshake, + Zeros: c.paddingLengthGenerator(uint(len(handshakeFragment))), + } + rawInner, err := inner.Marshal() //nolint:govet + if err != nil { + return nil, err + } + cidHeader := &recordlayer.Header{ + Version: p.record.Header.Version, + ContentType: protocol.ContentTypeConnectionID, + Epoch: p.record.Header.Epoch, + ContentLen: uint16(len(rawInner)), + ConnectionID: c.state.remoteConnectionID, + SequenceNumber: p.record.Header.SequenceNumber, + } + rawPacket, err = cidHeader.Marshal() + if err != nil { + return nil, err + } + p.record.Header = *cidHeader + rawPacket = append(rawPacket, rawInner...) + } else { + recordlayerHeader := &recordlayer.Header{ + Version: p.record.Header.Version, + ContentType: p.record.Header.ContentType, + ContentLen: uint16(len(handshakeFragment)), + Epoch: p.record.Header.Epoch, + SequenceNumber: seq, + } + + rawPacket, err = recordlayerHeader.Marshal() + if err != nil { + return nil, err + } + + p.record.Header = *recordlayerHeader + rawPacket = append(rawPacket, handshakeFragment...) 
+ }
+
+ if p.shouldEncrypt {
+ var err error
+ rawPacket, err = c.state.cipherSuite.Encrypt(p.record, rawPacket)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ rawPackets = append(rawPackets, rawPacket)
+ }
+
+ return rawPackets, nil
+}
+
+func (c *Conn) fragmentHandshake(h *handshake.Handshake) ([][]byte, error) {
+ content, err := h.Message.Marshal()
+ if err != nil {
+ return nil, err
+ }
+
+ fragmentedHandshakes := make([][]byte, 0)
+
+ contentFragments := splitBytes(content, c.maximumTransmissionUnit)
+ if len(contentFragments) == 0 {
+ contentFragments = [][]byte{
+ {},
+ }
+ }
+
+ offset := 0
+ for _, contentFragment := range contentFragments {
+ contentFragmentLen := len(contentFragment)
+
+ headerFragment := &handshake.Header{
+ Type: h.Header.Type,
+ Length: h.Header.Length,
+ MessageSequence: h.Header.MessageSequence,
+ FragmentOffset: uint32(offset),
+ FragmentLength: uint32(contentFragmentLen),
+ }
+
+ offset += contentFragmentLen
+
+ fragmentedHandshake, err := headerFragment.Marshal()
+ if err != nil {
+ return nil, err
+ }
+
+ fragmentedHandshake = append(fragmentedHandshake, contentFragment...)
+ fragmentedHandshakes = append(fragmentedHandshakes, fragmentedHandshake)
+ }
+
+ return fragmentedHandshakes, nil
+}
+
+var poolReadBuffer = sync.Pool{ //nolint:gochecknoglobals
+ New: func() interface{} {
+ b := make([]byte, inboundBufferSize)
+ return &b
+ },
+}
+
+func (c *Conn) readAndBuffer(ctx context.Context) error {
+ bufptr, ok := poolReadBuffer.Get().(*[]byte)
+ if !ok {
+ return errFailedToAccessPoolReadBuffer
+ }
+ defer poolReadBuffer.Put(bufptr)
+
+ b := *bufptr
+ i, rAddr, err := c.nextConn.ReadFromContext(ctx, b)
+ if err != nil {
+ return netError(err)
+ }
+
+ pkts, err := recordlayer.ContentAwareUnpackDatagram(b[:i], len(c.state.localConnectionID))
+ if err != nil {
+ return err
+ }
+
+ var hasHandshake bool
+ for _, p := range pkts {
+ hs, alert, err := c.handleIncomingPacket(ctx, p, rAddr, true)
+ if alert != nil {
+ if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil {
+ if err == nil {
+ err = alertErr
+ }
+ }
+ }
+
+ var e *alertError
+ if errors.As(err, &e) && e.IsFatalOrCloseNotify() {
+ return e
+ }
+ if err != nil {
+ return err
+ }
+ if hs {
+ hasHandshake = true
+ }
+ }
+ if hasHandshake {
+ done := make(chan struct{})
+ select {
+ case c.handshakeRecv <- done:
+ // The other party may retransmit the flight, so we should
+ // respond even if it is not a new message.
+ <-done
+ case <-c.fsm.Done():
+ }
+ }
+ return nil
+}
+
+func (c *Conn) handleQueuedPackets(ctx context.Context) error {
+ pkts := c.encryptedPackets
+ c.encryptedPackets = nil
+
+ for _, p := range pkts {
+ _, alert, err := c.handleIncomingPacket(ctx, p.data, p.rAddr, false) // don't re-enqueue
+ if alert != nil {
+ if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil {
+ if err == nil {
+ err = alertErr
+ }
+ }
+ }
+ var e *alertError
+ if errors.As(err, &e) && e.IsFatalOrCloseNotify() {
+ return e
+ }
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.Addr, enqueue bool) (bool, *alert.Alert, error) { //nolint:gocognit
+ h := &recordlayer.Header{}
+ // Set connection ID size so that records of content type tls12_cid will
+ // be parsed correctly.
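readAndBuffer above recycles its 8 KiB inbound buffers through a sync.Pool; storing a *[]byte rather than the slice itself keeps Put from allocating a fresh interface value on every cycle. The same pattern in isolation:

package example

import "sync"

const bufSize = 8192

var bufPool = sync.Pool{
	New: func() interface{} {
		b := make([]byte, bufSize)
		return &b // a pointer avoids boxing a new slice header on each Put
	},
}

func withBuffer(fn func([]byte)) {
	bufptr := bufPool.Get().(*[]byte)
	defer bufPool.Put(bufptr)
	fn(*bufptr)
}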
+ if len(c.state.localConnectionID) > 0 { + h.ConnectionID = make([]byte, len(c.state.localConnectionID)) + } + if err := h.Unmarshal(buf); err != nil { + // Decode error must be silently discarded + // [RFC6347 Section-4.1.2.7] + c.log.Debugf("discarded broken packet: %v", err) + return false, nil, nil + } + + // Validate epoch + remoteEpoch := c.state.getRemoteEpoch() + if h.Epoch > remoteEpoch { + if h.Epoch > remoteEpoch+1 { + c.log.Debugf("discarded future packet (epoch: %d, seq: %d)", + h.Epoch, h.SequenceNumber, + ) + return false, nil, nil + } + if enqueue { + c.log.Debug("received packet of next epoch, queuing packet") + c.encryptedPackets = append(c.encryptedPackets, addrPkt{rAddr, buf}) + } + return false, nil, nil + } + + // Anti-replay protection + for len(c.state.replayDetector) <= int(h.Epoch) { + c.state.replayDetector = append(c.state.replayDetector, + replaydetector.New(c.replayProtectionWindow, recordlayer.MaxSequenceNumber), + ) + } + markPacketAsValid, ok := c.state.replayDetector[int(h.Epoch)].Check(h.SequenceNumber) + if !ok { + c.log.Debugf("discarded duplicated packet (epoch: %d, seq: %d)", + h.Epoch, h.SequenceNumber, + ) + return false, nil, nil + } + + // originalCID indicates whether the original record had content type + // Connection ID. + originalCID := false + + // Decrypt + if h.Epoch != 0 { + if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() { + if enqueue { + c.encryptedPackets = append(c.encryptedPackets, addrPkt{rAddr, buf}) + c.log.Debug("handshake not finished, queuing packet") + } + return false, nil, nil + } + + // If a connection identifier had been negotiated and encryption is + // enabled, the connection identifier MUST be sent. + if len(c.state.localConnectionID) > 0 && h.ContentType != protocol.ContentTypeConnectionID { + c.log.Debug("discarded packet missing connection ID after value negotiated") + return false, nil, nil + } + + var err error + var hdr recordlayer.Header + if h.ContentType == protocol.ContentTypeConnectionID { + hdr.ConnectionID = make([]byte, len(c.state.localConnectionID)) + } + buf, err = c.state.cipherSuite.Decrypt(hdr, buf) + if err != nil { + c.log.Debugf("%s: decrypt failed: %s", srvCliStr(c.state.isClient), err) + return false, nil, nil + } + // If this is a connection ID record, make it look like a normal record for + // further processing. + if h.ContentType == protocol.ContentTypeConnectionID { + originalCID = true + ip := &recordlayer.InnerPlaintext{} + if err := ip.Unmarshal(buf[h.Size():]); err != nil { //nolint:govet + c.log.Debugf("unpacking inner plaintext failed: %s", err) + return false, nil, nil + } + unpacked := &recordlayer.Header{ + ContentType: ip.RealType, + ContentLen: uint16(len(ip.Content)), + Version: h.Version, + Epoch: h.Epoch, + SequenceNumber: h.SequenceNumber, + } + buf, err = unpacked.Marshal() + if err != nil { + c.log.Debugf("converting CID record to inner plaintext failed: %s", err) + return false, nil, nil + } + buf = append(buf, ip.Content...) + } + + // If connection ID does not match discard the packet. 
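The anti-replay check above keeps one sliding-window detector per epoch and only marks a sequence number as seen once the record authenticates (markPacketAsValid). Conceptually the window behaves like this simplified sketch; it is not the pion/transport replaydetector implementation, which uses a bitmask rather than a map:

package example

// windowDetector is a simplified fixed-size sliding window: sequence
// numbers older than (latest - size) are rejected, as are duplicates
// already seen inside the window.
type windowDetector struct {
	latest uint64
	size   uint64
	seen   map[uint64]bool
}

func newWindowDetector(size uint64) *windowDetector {
	return &windowDetector{size: size, seen: map[uint64]bool{}}
}

func (d *windowDetector) check(seq uint64) bool {
	switch {
	case seq+d.size <= d.latest: // too old: outside the window
		return false
	case d.seen[seq]: // duplicate inside the window
		return false
	}
	return true
}

// accept is the analogue of markPacketAsValid: it is only called after
// the record has been authenticated, so forged packets cannot advance
// the window.
func (d *windowDetector) accept(seq uint64) {
	d.seen[seq] = true
	if seq > d.latest {
		d.latest = seq
	}
	// A real implementation evicts entries that fall out of the window;
	// omitted here for brevity.
}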
+ if !bytes.Equal(c.state.localConnectionID, h.ConnectionID) { + c.log.Debug("unexpected connection ID") + return false, nil, nil + } + } + + isHandshake, err := c.fragmentBuffer.push(append([]byte{}, buf...)) + if err != nil { + // Decode error must be silently discarded + // [RFC6347 Section-4.1.2.7] + c.log.Debugf("defragment failed: %s", err) + return false, nil, nil + } else if isHandshake { + markPacketAsValid() + for out, epoch := c.fragmentBuffer.pop(); out != nil; out, epoch = c.fragmentBuffer.pop() { + header := &handshake.Header{} + if err := header.Unmarshal(out); err != nil { + c.log.Debugf("%s: handshake parse failed: %s", srvCliStr(c.state.isClient), err) + continue + } + c.handshakeCache.push(out, epoch, header.MessageSequence, header.Type, !c.state.isClient) + } + + return true, nil, nil + } + + r := &recordlayer.RecordLayer{} + if err := r.Unmarshal(buf); err != nil { + return false, &alert.Alert{Level: alert.Fatal, Description: alert.DecodeError}, err + } + + isLatestSeqNum := false + switch content := r.Content.(type) { + case *alert.Alert: + c.log.Tracef("%s: <- %s", srvCliStr(c.state.isClient), content.String()) + var a *alert.Alert + if content.Description == alert.CloseNotify { + // Respond with a close_notify [RFC5246 Section 7.2.1] + a = &alert.Alert{Level: alert.Warning, Description: alert.CloseNotify} + } + _ = markPacketAsValid() + return false, a, &alertError{content} + case *protocol.ChangeCipherSpec: + if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() { + if enqueue { + c.encryptedPackets = append(c.encryptedPackets, addrPkt{rAddr, buf}) + c.log.Debugf("CipherSuite not initialized, queuing packet") + } + return false, nil, nil + } + + newRemoteEpoch := h.Epoch + 1 + c.log.Tracef("%s: <- ChangeCipherSpec (epoch: %d)", srvCliStr(c.state.isClient), newRemoteEpoch) + + if c.state.getRemoteEpoch()+1 == newRemoteEpoch { + c.setRemoteEpoch(newRemoteEpoch) + isLatestSeqNum = markPacketAsValid() + } + case *protocol.ApplicationData: + if h.Epoch == 0 { + return false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, errApplicationDataEpochZero + } + + isLatestSeqNum = markPacketAsValid() + + select { + case c.decrypted <- content.Data: + case <-c.closed.Done(): + case <-ctx.Done(): + } + + default: + return false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, fmt.Errorf("%w: %d", errUnhandledContextType, content.ContentType()) + } + + // Any valid connection ID record is a candidate for updating the remote + // address if it is the latest record received. + // https://datatracker.ietf.org/doc/html/rfc9146#peer-address-update + if originalCID && isLatestSeqNum { + if rAddr != c.RemoteAddr() { + c.lock.Lock() + c.rAddr = rAddr + c.lock.Unlock() + } + } + + return false, nil, nil +} + +func (c *Conn) recvHandshake() <-chan chan struct{} { + return c.handshakeRecv +} + +func (c *Conn) notify(ctx context.Context, level alert.Level, desc alert.Description) error { + if level == alert.Fatal && len(c.state.SessionID) > 0 { + // According to the RFC, we need to delete the stored session. 
+ // https://datatracker.ietf.org/doc/html/rfc5246#section-7.2
+ if ss := c.fsm.cfg.sessionStore; ss != nil {
+ c.log.Tracef("clean invalid session: %s", c.state.SessionID)
+ if err := ss.Del(c.sessionKey()); err != nil {
+ return err
+ }
+ }
+ }
+ return c.writePackets(ctx, []*packet{
+ {
+ record: &recordlayer.RecordLayer{
+ Header: recordlayer.Header{
+ Epoch: c.state.getLocalEpoch(),
+ Version: protocol.Version1_2,
+ },
+ Content: &alert.Alert{
+ Level: level,
+ Description: desc,
+ },
+ },
+ shouldWrapCID: len(c.state.remoteConnectionID) > 0,
+ shouldEncrypt: c.isHandshakeCompletedSuccessfully(),
+ },
+ })
+}
+
+func (c *Conn) setHandshakeCompletedSuccessfully() {
+ c.handshakeCompletedSuccessfully.Store(struct{ bool }{true})
+}
+
+func (c *Conn) isHandshakeCompletedSuccessfully() bool {
+ boolean, _ := c.handshakeCompletedSuccessfully.Load().(struct{ bool })
+ return boolean.bool
+}
+
+func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFlight flightVal, initialState handshakeState) error { //nolint:gocognit
+ c.fsm = newHandshakeFSM(&c.state, c.handshakeCache, cfg, initialFlight)
+
+ done := make(chan struct{})
+ ctxRead, cancelRead := context.WithCancel(context.Background())
+ c.cancelHandshakeReader = cancelRead
+ cfg.onFlightState = func(f flightVal, s handshakeState) {
+ if s == handshakeFinished && !c.isHandshakeCompletedSuccessfully() {
+ c.setHandshakeCompletedSuccessfully()
+ close(done)
+ }
+ }
+
+ ctxHs, cancel := context.WithCancel(context.Background())
+ c.cancelHandshaker = cancel
+
+ firstErr := make(chan error, 1)
+
+ c.handshakeLoopsFinished.Add(2)
+
+ // Handshake routine should be live until close.
+ // The other party may request retransmission of the last flight to cope with packet drop.
+ go func() {
+ defer c.handshakeLoopsFinished.Done()
+ err := c.fsm.Run(ctxHs, c, initialState)
+ if !errors.Is(err, context.Canceled) {
+ select {
+ case firstErr <- err:
+ default:
+ }
+ }
+ }()
+ go func() {
+ defer func() {
+ // Escaping the read loop.
+ // It's now safe to close the decrypted channel.
+ close(c.decrypted)
+
+ // Force stop handshaker when the underlying connection is closed.
+ cancel()
+ }()
+ defer c.handshakeLoopsFinished.Done()
+ for {
+ if err := c.readAndBuffer(ctxRead); err != nil {
+ var e *alertError
+ if errors.As(err, &e) {
+ if !e.IsFatalOrCloseNotify() {
+ if c.isHandshakeCompletedSuccessfully() {
+ // Pass the error to Read()
+ select {
+ case c.decrypted <- err:
+ case <-c.closed.Done():
+ case <-ctxRead.Done():
+ }
+ }
+ continue // non-fatal alert must not stop read loop
+ }
+ } else {
+ switch {
+ case errors.Is(err, context.DeadlineExceeded), errors.Is(err, context.Canceled), errors.Is(err, io.EOF), errors.Is(err, net.ErrClosed):
+ default:
+ if c.isHandshakeCompletedSuccessfully() {
+ // Keep the read loop alive and pass the read error to Read()
+ select {
+ case c.decrypted <- err:
+ case <-c.closed.Done():
+ case <-ctxRead.Done():
+ }
+ continue // non-fatal error must not stop read loop
+ }
+ }
+ }
+
+ select {
+ case firstErr <- err:
+ default:
+ }
+
+ if e != nil {
+ if e.IsFatalOrCloseNotify() {
+ _ = c.close(false) //nolint:contextcheck
+ }
+ }
+ if !c.isConnectionClosed() && errors.Is(err, context.Canceled) {
+ c.log.Trace("handshake timed out - closing underlying connection")
+ _ = c.close(false) //nolint:contextcheck
+ }
+ return
+ }
+ }
+ }()
+
+ select {
+ case err := <-firstErr:
+ cancelRead()
+ cancel()
+ c.handshakeLoopsFinished.Wait()
+ return c.translateHandshakeCtxError(err)
+ case <-ctx.Done():
+ cancelRead()
+ cancel()
+ c.handshakeLoopsFinished.Wait()
+ return c.translateHandshakeCtxError(ctx.Err())
+ case <-done:
+ return nil
+ }
+}
+
+func (c *Conn) translateHandshakeCtxError(err error) error {
+ if err == nil {
+ return nil
+ }
+ if errors.Is(err, context.Canceled) && c.isHandshakeCompletedSuccessfully() {
+ return nil
+ }
+ return &HandshakeError{Err: err}
+}
+
+func (c *Conn) close(byUser bool) error {
+ c.cancelHandshaker()
+ c.cancelHandshakeReader()
+
+ if c.isHandshakeCompletedSuccessfully() && byUser {
+ // Discard error from notify() to return non-error on the first user call of Close()
+ // even if the underlying connection is already closed.
+ _ = c.notify(context.Background(), alert.Warning, alert.CloseNotify)
+ }
+
+ c.closeLock.Lock()
+ // Don't return ErrConnClosed on the first call from the user.
+ closedByUser := c.connectionClosedByUser
+ if byUser {
+ c.connectionClosedByUser = true
+ }
+ isClosed := c.isConnectionClosed()
+ c.closed.Close()
+ c.closeLock.Unlock()
+
+ if closedByUser {
+ return ErrConnClosed
+ }
+
+ if isClosed {
+ return nil
+ }
+
+ return c.nextConn.Close()
+}
+
+func (c *Conn) isConnectionClosed() bool {
+ select {
+ case <-c.closed.Done():
+ return true
+ default:
+ return false
+ }
+}
+
+func (c *Conn) setLocalEpoch(epoch uint16) {
+ c.state.localEpoch.Store(epoch)
+}
+
+func (c *Conn) setRemoteEpoch(epoch uint16) {
+ c.state.remoteEpoch.Store(epoch)
+}
+
+// LocalAddr implements net.Conn.LocalAddr
+func (c *Conn) LocalAddr() net.Addr {
+ return c.nextConn.LocalAddr()
+}
+
+// RemoteAddr implements net.Conn.RemoteAddr
+func (c *Conn) RemoteAddr() net.Addr {
+ c.lock.RLock()
+ defer c.lock.RUnlock()
+ return c.rAddr
+}
+
+func (c *Conn) sessionKey() []byte {
+ if c.state.isClient {
+ // As ServerName can be like 0.example.com, it's better to add a
+ // delimiter character that is not allowed in either an address
+ // or a domain name.
+ return []byte(c.rAddr.String() + "_" + c.fsm.cfg.serverName) + } + return c.state.SessionID +} + +// SetDeadline implements net.Conn.SetDeadline +func (c *Conn) SetDeadline(t time.Time) error { + c.readDeadline.Set(t) + return c.SetWriteDeadline(t) +} + +// SetReadDeadline implements net.Conn.SetReadDeadline +func (c *Conn) SetReadDeadline(t time.Time) error { + c.readDeadline.Set(t) + // Read deadline is fully managed by this layer. + // Don't set read deadline to underlying connection. + return nil +} + +// SetWriteDeadline implements net.Conn.SetWriteDeadline +func (c *Conn) SetWriteDeadline(t time.Time) error { + c.writeDeadline.Set(t) + // Write deadline is also fully managed by this layer. + return nil +} diff --git a/vendor/github.com/pion/dtls/v2/connection_id.go b/vendor/github.com/pion/dtls/v2/connection_id.go new file mode 100644 index 0000000000..b2fbbd7a87 --- /dev/null +++ b/vendor/github.com/pion/dtls/v2/connection_id.go @@ -0,0 +1,101 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "crypto/rand" + + "github.com/pion/dtls/v2/pkg/protocol" + "github.com/pion/dtls/v2/pkg/protocol/extension" + "github.com/pion/dtls/v2/pkg/protocol/handshake" + "github.com/pion/dtls/v2/pkg/protocol/recordlayer" +) + +// RandomCIDGenerator is a random Connection ID generator where CID is the +// specified size. Specifying a size of 0 will indicate to peers that sending a +// Connection ID is not necessary. +func RandomCIDGenerator(size int) func() []byte { + return func() []byte { + cid := make([]byte, size) + if _, err := rand.Read(cid); err != nil { + panic(err) //nolint -- nonrecoverable + } + return cid + } +} + +// OnlySendCIDGenerator enables sending Connection IDs negotiated with a peer, +// but indicates to the peer that sending Connection IDs in return is not +// necessary. +func OnlySendCIDGenerator() func() []byte { + return func() []byte { + return nil + } +} + +// cidDatagramRouter extracts connection IDs from incoming datagram payloads and +// uses them to route to the proper connection. +// NOTE: properly routing datagrams based on connection IDs requires using +// constant size connection IDs. +func cidDatagramRouter(size int) func([]byte) (string, bool) { + return func(packet []byte) (string, bool) { + pkts, err := recordlayer.ContentAwareUnpackDatagram(packet, size) + if err != nil || len(pkts) < 1 { + return "", false + } + for _, pkt := range pkts { + h := &recordlayer.Header{ + ConnectionID: make([]byte, size), + } + if err := h.Unmarshal(pkt); err != nil { + continue + } + if h.ContentType != protocol.ContentTypeConnectionID { + continue + } + return string(h.ConnectionID), true + } + return "", false + } +} + +// cidConnIdentifier extracts connection IDs from outgoing ServerHello records +// and associates them with the associated connection. +// NOTE: a ServerHello should always be the first record in a datagram if +// multiple are present, so we avoid iterating through all packets if the first +// is not a ServerHello. 
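Both generators above plug into the Config.ConnectionIDGenerator field defined earlier in this patch. A sketch of the three possible negotiation stances:

package example

import "github.com/pion/dtls/v2"

func cidConfigs() []*dtls.Config {
	return []*dtls.Config{
		// Ask the peer to address us with a random 8-byte CID, and
		// offer to send the peer's CID in return.
		{ConnectionIDGenerator: dtls.RandomCIDGenerator(8)},
		// Send the peer's CID, but request none ourselves
		// (zero-length local CID).
		{ConnectionIDGenerator: dtls.OnlySendCIDGenerator()},
		// nil generator: connection IDs are not supported at all.
		{},
	}
}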
+func cidConnIdentifier() func([]byte) (string, bool) { + return func(packet []byte) (string, bool) { + pkts, err := recordlayer.UnpackDatagram(packet) + if err != nil || len(pkts) < 1 { + return "", false + } + var h recordlayer.Header + if hErr := h.Unmarshal(pkts[0]); hErr != nil { + return "", false + } + if h.ContentType != protocol.ContentTypeHandshake { + return "", false + } + var hh handshake.Header + var sh handshake.MessageServerHello + for _, pkt := range pkts { + if hhErr := hh.Unmarshal(pkt[recordlayer.FixedHeaderSize:]); hhErr != nil { + continue + } + if err = sh.Unmarshal(pkt[recordlayer.FixedHeaderSize+handshake.HeaderLength:]); err == nil { + break + } + } + if err != nil { + return "", false + } + for _, ext := range sh.Extensions { + if e, ok := ext.(*extension.ConnectionID); ok { + return string(e.CID), true + } + } + return "", false + } +} diff --git a/vendor/github.com/pion/dtls/v2/flight0handler.go b/vendor/github.com/pion/dtls/v2/flight0handler.go new file mode 100644 index 0000000000..648c52883a --- /dev/null +++ b/vendor/github.com/pion/dtls/v2/flight0handler.go @@ -0,0 +1,156 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "context" + "crypto/rand" + + "github.com/pion/dtls/v2/pkg/crypto/elliptic" + "github.com/pion/dtls/v2/pkg/protocol" + "github.com/pion/dtls/v2/pkg/protocol/alert" + "github.com/pion/dtls/v2/pkg/protocol/extension" + "github.com/pion/dtls/v2/pkg/protocol/handshake" +) + +func flight0Parse(_ context.Context, _ flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { + seq, msgs, ok := cache.fullPullMap(0, state.cipherSuite, + handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, + ) + if !ok { + // No valid message received. Keep reading + return 0, nil, nil + } + + // Connection Identifiers must be negotiated afresh on session resumption. 
+ // https://datatracker.ietf.org/doc/html/rfc9146#name-the-connection_id-extension + state.localConnectionID = nil + state.remoteConnectionID = nil + + state.handshakeRecvSequence = seq + + var clientHello *handshake.MessageClientHello + + // Validate type + if clientHello, ok = msgs[handshake.TypeClientHello].(*handshake.MessageClientHello); !ok { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil + } + + if !clientHello.Version.Equal(protocol.Version1_2) { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion + } + + state.remoteRandom = clientHello.Random + + cipherSuites := []CipherSuite{} + for _, id := range clientHello.CipherSuiteIDs { + if c := cipherSuiteForID(CipherSuiteID(id), cfg.customCipherSuites); c != nil { + cipherSuites = append(cipherSuites, c) + } + } + + if state.cipherSuite, ok = findMatchingCipherSuite(cipherSuites, cfg.localCipherSuites); !ok { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errCipherSuiteNoIntersection + } + + for _, val := range clientHello.Extensions { + switch e := val.(type) { + case *extension.SupportedEllipticCurves: + if len(e.EllipticCurves) == 0 { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errNoSupportedEllipticCurves + } + state.namedCurve = e.EllipticCurves[0] + case *extension.UseSRTP: + profile, ok := findMatchingSRTPProfile(e.ProtectionProfiles, cfg.localSRTPProtectionProfiles) + if !ok { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errServerNoMatchingSRTPProfile + } + state.srtpProtectionProfile = profile + case *extension.UseExtendedMasterSecret: + if cfg.extendedMasterSecret != DisableExtendedMasterSecret { + state.extendedMasterSecret = true + } + case *extension.ServerName: + state.serverName = e.ServerName // remote server name + case *extension.ALPN: + state.peerSupportedProtocols = e.ProtocolNameList + case *extension.ConnectionID: + // Only set connection ID to be sent if server supports connection + // IDs. + if cfg.connectionIDGenerator != nil { + state.remoteConnectionID = e.CID + } + } + } + + // If the client doesn't support connection IDs, the server should not + // expect one to be sent. 
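flight0Parse accepts the first offered suite that findMatchingCipherSuite can pair with the server's list, so a server narrows negotiation by listing suites explicitly. A sketch, assuming the suite constants are named as in the library's cipher suite registry:

package example

import "github.com/pion/dtls/v2"

func serverConfig() *dtls.Config {
	return &dtls.Config{
		// Only these suites can be matched against the ClientHello;
		// anything else fails with an insufficient_security alert.
		CipherSuites: []dtls.CipherSuiteID{
			dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
			dtls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
		},
		// Refuse clients that do not offer the extended master secret
		// extension (RequireExtendedMasterSecret, defined above).
		ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
	}
}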
+ if state.remoteConnectionID == nil { + state.localConnectionID = nil + } + + if cfg.extendedMasterSecret == RequireExtendedMasterSecret && !state.extendedMasterSecret { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errServerRequiredButNoClientEMS + } + + if state.localKeypair == nil { + var err error + state.localKeypair, err = elliptic.GenerateKeypair(state.namedCurve) + if err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, err + } + } + + nextFlight := flight2 + + if cfg.insecureSkipHelloVerify { + nextFlight = flight4 + } + + return handleHelloResume(clientHello.SessionID, state, cfg, nextFlight) +} + +func handleHelloResume(sessionID []byte, state *State, cfg *handshakeConfig, next flightVal) (flightVal, *alert.Alert, error) { + if len(sessionID) > 0 && cfg.sessionStore != nil { + if s, err := cfg.sessionStore.Get(sessionID); err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } else if s.ID != nil { + cfg.log.Tracef("[handshake] resume session: %x", sessionID) + + state.SessionID = sessionID + state.masterSecret = s.Secret + + if err := state.initCipherSuite(); err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + + clientRandom := state.localRandom.MarshalFixed() + cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret) + + return flight4b, nil, nil + } + } + return next, nil, nil +} + +func flight0Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) { + // Initialize + if !cfg.insecureSkipHelloVerify { + state.cookie = make([]byte, cookieLength) + if _, err := rand.Read(state.cookie); err != nil { + return nil, nil, err + } + } + + var zeroEpoch uint16 + state.localEpoch.Store(zeroEpoch) + state.remoteEpoch.Store(zeroEpoch) + state.namedCurve = defaultNamedCurve + + if err := state.localRandom.Populate(); err != nil { + return nil, nil, err + } + + return nil, nil, nil +} diff --git a/vendor/github.com/pion/dtls/v2/flight1handler.go b/vendor/github.com/pion/dtls/v2/flight1handler.go new file mode 100644 index 0000000000..48bc88213e --- /dev/null +++ b/vendor/github.com/pion/dtls/v2/flight1handler.go @@ -0,0 +1,156 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "context" + + "github.com/pion/dtls/v2/pkg/crypto/elliptic" + "github.com/pion/dtls/v2/pkg/protocol" + "github.com/pion/dtls/v2/pkg/protocol/alert" + "github.com/pion/dtls/v2/pkg/protocol/extension" + "github.com/pion/dtls/v2/pkg/protocol/handshake" + "github.com/pion/dtls/v2/pkg/protocol/recordlayer" +) + +func flight1Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { + // HelloVerifyRequest can be skipped by the server, + // so allow ServerHello during flight1 also + seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite, + handshakeCachePullRule{handshake.TypeHelloVerifyRequest, cfg.initialEpoch, false, true}, + handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, true}, + ) + if !ok { + // No valid message received. Keep reading + return 0, nil, nil + } + + if _, ok := msgs[handshake.TypeServerHello]; ok { + // Flight1 and flight2 were skipped. + // Parse as flight3. 
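handleHelloResume and flight0Generate above drive resumption through cfg.sessionStore. Judging from the calls visible in this patch (Get returning a Session whose nil ID means a miss, plus Set and Del keyed by the bytes from sessionKey), an in-memory store could look roughly like this sketch; the authoritative SessionStore interface is defined elsewhere in the package:

package example

import (
	"sync"

	"github.com/pion/dtls/v2"
)

// memSessionStore is a naive in-memory session cache keyed by the byte
// key the library passes in: the session ID on the server, addr+SNI on
// the client.
type memSessionStore struct {
	mu sync.Mutex
	m  map[string]dtls.Session
}

func (s *memSessionStore) Set(key []byte, ses dtls.Session) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.m == nil {
		s.m = map[string]dtls.Session{}
	}
	s.m[string(key)] = ses
	return nil
}

func (s *memSessionStore) Get(key []byte) (dtls.Session, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.m[string(key)], nil // zero Session (nil ID) signals a miss
}

func (s *memSessionStore) Del(key []byte) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	delete(s.m, string(key))
	return nil
}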
+ return flight3Parse(ctx, c, state, cache, cfg) + } + + if h, ok := msgs[handshake.TypeHelloVerifyRequest].(*handshake.MessageHelloVerifyRequest); ok { + // DTLS 1.2 clients must not assume that the server will use the protocol version + // specified in HelloVerifyRequest message. RFC 6347 Section 4.2.1 + if !h.Version.Equal(protocol.Version1_0) && !h.Version.Equal(protocol.Version1_2) { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion + } + state.cookie = append([]byte{}, h.Cookie...) + state.handshakeRecvSequence = seq + return flight3, nil, nil + } + + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil +} + +func flight1Generate(c flightConn, state *State, _ *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) { + var zeroEpoch uint16 + state.localEpoch.Store(zeroEpoch) + state.remoteEpoch.Store(zeroEpoch) + state.namedCurve = defaultNamedCurve + state.cookie = nil + + if err := state.localRandom.Populate(); err != nil { + return nil, nil, err + } + + extensions := []extension.Extension{ + &extension.SupportedSignatureAlgorithms{ + SignatureHashAlgorithms: cfg.localSignatureSchemes, + }, + &extension.RenegotiationInfo{ + RenegotiatedConnection: 0, + }, + } + + var setEllipticCurveCryptographyClientHelloExtensions bool + for _, c := range cfg.localCipherSuites { + if c.ECC() { + setEllipticCurveCryptographyClientHelloExtensions = true + break + } + } + + if setEllipticCurveCryptographyClientHelloExtensions { + extensions = append(extensions, []extension.Extension{ + &extension.SupportedEllipticCurves{ + EllipticCurves: cfg.ellipticCurves, + }, + &extension.SupportedPointFormats{ + PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed}, + }, + }...) + } + + if len(cfg.localSRTPProtectionProfiles) > 0 { + extensions = append(extensions, &extension.UseSRTP{ + ProtectionProfiles: cfg.localSRTPProtectionProfiles, + }) + } + + if cfg.extendedMasterSecret == RequestExtendedMasterSecret || + cfg.extendedMasterSecret == RequireExtendedMasterSecret { + extensions = append(extensions, &extension.UseExtendedMasterSecret{ + Supported: true, + }) + } + + if len(cfg.serverName) > 0 { + extensions = append(extensions, &extension.ServerName{ServerName: cfg.serverName}) + } + + if len(cfg.supportedProtocols) > 0 { + extensions = append(extensions, &extension.ALPN{ProtocolNameList: cfg.supportedProtocols}) + } + + if cfg.sessionStore != nil { + cfg.log.Tracef("[handshake] try to resume session") + if s, err := cfg.sessionStore.Get(c.sessionKey()); err != nil { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } else if s.ID != nil { + cfg.log.Tracef("[handshake] get saved session: %x", s.ID) + + state.SessionID = s.ID + state.masterSecret = s.Secret + } + } + + // If we have a connection ID generator, use it. The CID may be zero length, + // in which case we are just requesting that the server send us a CID to + // use. + if cfg.connectionIDGenerator != nil { + state.localConnectionID = cfg.connectionIDGenerator() + // The presence of a generator indicates support for connection IDs. We + // use the presence of a non-nil local CID in flight 3 to determine + // whether we send a CID in the second ClientHello, so we convert any + // nil CID returned by a generator to []byte{}. 
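The ALPN extension appended in flight1Generate carries cfg.supportedProtocols; after the handshake, the protocol the server selected is exposed on the connection state. A client-side sketch with an illustrative address and key material:

package main

import (
	"fmt"
	"net"

	"github.com/pion/dtls/v2"
)

func main() {
	rAddr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 5684}
	cfg := &dtls.Config{
		SupportedProtocols: []string{"coap"}, // offered via the ALPN extension
		PSK:                func([]byte) ([]byte, error) { return []byte{1, 2, 3}, nil },
		PSKIdentityHint:    []byte("client-1"),
	}
	conn, err := dtls.Dial("udp", rAddr, cfg)
	if err != nil {
		panic(err)
	}
	defer conn.Close()
	// NegotiatedProtocol is populated from the ServerHello's ALPN answer.
	fmt.Println("ALPN:", conn.ConnectionState().NegotiatedProtocol)
}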
+ if state.localConnectionID == nil { + state.localConnectionID = []byte{} + } + extensions = append(extensions, &extension.ConnectionID{CID: state.localConnectionID}) + } + + return []*packet{ + { + record: &recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Version: protocol.Version1_2, + }, + Content: &handshake.Handshake{ + Message: &handshake.MessageClientHello{ + Version: protocol.Version1_2, + SessionID: state.SessionID, + Cookie: state.cookie, + Random: state.localRandom, + CipherSuiteIDs: cipherSuiteIDs(cfg.localCipherSuites), + CompressionMethods: defaultCompressionMethods(), + Extensions: extensions, + }, + }, + }, + }, + }, nil, nil +} diff --git a/vendor/github.com/pion/dtls/v2/flight3handler.go b/vendor/github.com/pion/dtls/v2/flight3handler.go new file mode 100644 index 0000000000..b17a7986c3 --- /dev/null +++ b/vendor/github.com/pion/dtls/v2/flight3handler.go @@ -0,0 +1,309 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "bytes" + "context" + + "github.com/pion/dtls/v2/internal/ciphersuite/types" + "github.com/pion/dtls/v2/pkg/crypto/elliptic" + "github.com/pion/dtls/v2/pkg/crypto/prf" + "github.com/pion/dtls/v2/pkg/protocol" + "github.com/pion/dtls/v2/pkg/protocol/alert" + "github.com/pion/dtls/v2/pkg/protocol/extension" + "github.com/pion/dtls/v2/pkg/protocol/handshake" + "github.com/pion/dtls/v2/pkg/protocol/recordlayer" +) + +func flight3Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { //nolint:gocognit + // Clients may receive multiple HelloVerifyRequest messages with different cookies. + // Clients SHOULD handle this by sending a new ClientHello with a cookie in response + // to the new HelloVerifyRequest. RFC 6347 Section 4.2.1 + seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite, + handshakeCachePullRule{handshake.TypeHelloVerifyRequest, cfg.initialEpoch, false, true}, + ) + if ok { + if h, msgOk := msgs[handshake.TypeHelloVerifyRequest].(*handshake.MessageHelloVerifyRequest); msgOk { + // DTLS 1.2 clients must not assume that the server will use the protocol version + // specified in HelloVerifyRequest message. RFC 6347 Section 4.2.1 + if !h.Version.Equal(protocol.Version1_0) && !h.Version.Equal(protocol.Version1_2) { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion + } + state.cookie = append([]byte{}, h.Cookie...) + state.handshakeRecvSequence = seq + return flight3, nil, nil + } + } + + _, msgs, ok = cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite, + handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false}, + ) + if !ok { + // Don't have enough messages. 
Keep reading + return 0, nil, nil + } + + if h, msgOk := msgs[handshake.TypeServerHello].(*handshake.MessageServerHello); msgOk { + if !h.Version.Equal(protocol.Version1_2) { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion + } + for _, v := range h.Extensions { + switch e := v.(type) { + case *extension.UseSRTP: + profile, found := findMatchingSRTPProfile(e.ProtectionProfiles, cfg.localSRTPProtectionProfiles) + if !found { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, errClientNoMatchingSRTPProfile + } + state.srtpProtectionProfile = profile + case *extension.UseExtendedMasterSecret: + if cfg.extendedMasterSecret != DisableExtendedMasterSecret { + state.extendedMasterSecret = true + } + case *extension.ALPN: + if len(e.ProtocolNameList) > 1 { // This should be exactly 1, the zero case is handle when unmarshalling + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, extension.ErrALPNInvalidFormat // Meh, internal error? + } + state.NegotiatedProtocol = e.ProtocolNameList[0] + case *extension.ConnectionID: + // Only set connection ID to be sent if client supports connection + // IDs. + if cfg.connectionIDGenerator != nil { + state.remoteConnectionID = e.CID + } + } + } + // If the server doesn't support connection IDs, the client should not + // expect one to be sent. + if state.remoteConnectionID == nil { + state.localConnectionID = nil + } + + if cfg.extendedMasterSecret == RequireExtendedMasterSecret && !state.extendedMasterSecret { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errClientRequiredButNoServerEMS + } + if len(cfg.localSRTPProtectionProfiles) > 0 && state.srtpProtectionProfile == 0 { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errRequestedButNoSRTPExtension + } + + remoteCipherSuite := cipherSuiteForID(CipherSuiteID(*h.CipherSuiteID), cfg.customCipherSuites) + if remoteCipherSuite == nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errCipherSuiteNoIntersection + } + + selectedCipherSuite, found := findMatchingCipherSuite([]CipherSuite{remoteCipherSuite}, cfg.localCipherSuites) + if !found { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errInvalidCipherSuite + } + + state.cipherSuite = selectedCipherSuite + state.remoteRandom = h.Random + cfg.log.Tracef("[handshake] use cipher suite: %s", selectedCipherSuite.String()) + + if len(h.SessionID) > 0 && bytes.Equal(state.SessionID, h.SessionID) { + return handleResumption(ctx, c, state, cache, cfg) + } + + if len(state.SessionID) > 0 { + cfg.log.Tracef("[handshake] clean old session : %s", state.SessionID) + if err := cfg.sessionStore.Del(state.SessionID); err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + } + + if cfg.sessionStore == nil { + state.SessionID = []byte{} + } else { + state.SessionID = h.SessionID + } + + state.masterSecret = []byte{} + } + + if cfg.localPSKCallback != nil { + seq, msgs, ok = cache.fullPullMap(state.handshakeRecvSequence+1, state.cipherSuite, + handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, true}, + handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false}, + ) + } else { + seq, msgs, ok = cache.fullPullMap(state.handshakeRecvSequence+1, state.cipherSuite, + 
handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, true}, + handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, true}, + handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false}, + ) + } + if !ok { + // Don't have enough messages. Keep reading + return 0, nil, nil + } + state.handshakeRecvSequence = seq + + if h, ok := msgs[handshake.TypeCertificate].(*handshake.MessageCertificate); ok { + state.PeerCertificates = h.Certificate + } else if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errInvalidCertificate + } + + if h, ok := msgs[handshake.TypeServerKeyExchange].(*handshake.MessageServerKeyExchange); ok { + alertPtr, err := handleServerKeyExchange(c, state, cfg, h) + if err != nil { + return 0, alertPtr, err + } + } + + if _, ok := msgs[handshake.TypeCertificateRequest].(*handshake.MessageCertificateRequest); ok { + state.remoteRequestedCertificate = true + } + + return flight5, nil, nil +} + +func handleResumption(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { + if err := state.initCipherSuite(); err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + + // Now, encrypted packets can be handled + if err := c.handleQueuedPackets(ctx); err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + + _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence+1, state.cipherSuite, + handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, false, false}, + ) + if !ok { + // No valid message received. 
Keep reading + return 0, nil, nil + } + + var finished *handshake.MessageFinished + if finished, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil + } + plainText := cache.pullAndMerge( + handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, + handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false}, + ) + + expectedVerifyData, err := prf.VerifyDataServer(state.masterSecret, plainText, state.cipherSuite.HashFunc()) + if err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + if !bytes.Equal(expectedVerifyData, finished.VerifyData) { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errVerifyDataMismatch + } + + clientRandom := state.localRandom.MarshalFixed() + cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret) + + return flight5b, nil, nil +} + +func handleServerKeyExchange(_ flightConn, state *State, cfg *handshakeConfig, h *handshake.MessageServerKeyExchange) (*alert.Alert, error) { + var err error + if state.cipherSuite == nil { + return &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errInvalidCipherSuite + } + if cfg.localPSKCallback != nil { + var psk []byte + if psk, err = cfg.localPSKCallback(h.IdentityHint); err != nil { + return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + state.IdentityHint = h.IdentityHint + switch state.cipherSuite.KeyExchangeAlgorithm() { + case types.KeyExchangeAlgorithmPsk: + state.preMasterSecret = prf.PSKPreMasterSecret(psk) + case (types.KeyExchangeAlgorithmEcdhe | types.KeyExchangeAlgorithmPsk): + if state.localKeypair, err = elliptic.GenerateKeypair(h.NamedCurve); err != nil { + return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + state.preMasterSecret, err = prf.EcdhePSKPreMasterSecret(psk, h.PublicKey, state.localKeypair.PrivateKey, state.localKeypair.Curve) + if err != nil { + return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + default: + return &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errInvalidCipherSuite + } + } else { + if state.localKeypair, err = elliptic.GenerateKeypair(h.NamedCurve); err != nil { + return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + + if state.preMasterSecret, err = prf.PreMasterSecret(h.PublicKey, state.localKeypair.PrivateKey, state.localKeypair.Curve); err != nil { + return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + } + + return nil, nil //nolint:nilnil +} + +func flight3Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) { + extensions := []extension.Extension{ + &extension.SupportedSignatureAlgorithms{ + SignatureHashAlgorithms: cfg.localSignatureSchemes, + }, + &extension.RenegotiationInfo{ + RenegotiatedConnection: 0, + }, + } + if state.namedCurve != 0 { + extensions = append(extensions, []extension.Extension{ + &extension.SupportedEllipticCurves{ + EllipticCurves: []elliptic.Curve{elliptic.X25519, elliptic.P256, elliptic.P384}, + }, + &extension.SupportedPointFormats{ + PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed}, + }, + }...) 
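handleServerKeyExchange above hands h.IdentityHint to the configured PSK callback on the client, and flight4 (below) does the same on the server with the identity the client sent in its ClientKeyExchange. A sketch of a table-backed callback for Config.PSK, with hypothetical identities and keys:

package example

import "fmt"

var pskTable = map[string][]byte{
	"device-a": {0x01, 0x02, 0x03},
	"device-b": {0x04, 0x05, 0x06},
}

// lookupPSK resolves a peer identity to its pre-shared key; returning
// an error aborts the handshake with an internal_error alert.
func lookupPSK(identity []byte) ([]byte, error) {
	if psk, ok := pskTable[string(identity)]; ok {
		return psk, nil
	}
	return nil, fmt.Errorf("no PSK for identity %q", identity)
}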
+ }
+
+ if len(cfg.localSRTPProtectionProfiles) > 0 {
+ extensions = append(extensions, &extension.UseSRTP{
+ ProtectionProfiles: cfg.localSRTPProtectionProfiles,
+ })
+ }
+
+ if cfg.extendedMasterSecret == RequestExtendedMasterSecret ||
+ cfg.extendedMasterSecret == RequireExtendedMasterSecret {
+ extensions = append(extensions, &extension.UseExtendedMasterSecret{
+ Supported: true,
+ })
+ }
+
+ if len(cfg.serverName) > 0 {
+ extensions = append(extensions, &extension.ServerName{ServerName: cfg.serverName})
+ }
+
+ if len(cfg.supportedProtocols) > 0 {
+ extensions = append(extensions, &extension.ALPN{ProtocolNameList: cfg.supportedProtocols})
+ }
+
+ // If we sent a connection ID on the first ClientHello, send it on the
+ // second.
+ if state.localConnectionID != nil {
+ extensions = append(extensions, &extension.ConnectionID{CID: state.localConnectionID})
+ }
+
+ return []*packet{
+ {
+ record: &recordlayer.RecordLayer{
+ Header: recordlayer.Header{
+ Version: protocol.Version1_2,
+ },
+ Content: &handshake.Handshake{
+ Message: &handshake.MessageClientHello{
+ Version: protocol.Version1_2,
+ SessionID: state.SessionID,
+ Cookie: state.cookie,
+ Random: state.localRandom,
+ CipherSuiteIDs: cipherSuiteIDs(cfg.localCipherSuites),
+ CompressionMethods: defaultCompressionMethods(),
+ Extensions: extensions,
+ },
+ },
+ },
+ },
+ }, nil, nil
+}
diff --git a/vendor/github.com/pion/dtls/v2/flight4handler.go b/vendor/github.com/pion/dtls/v2/flight4handler.go
new file mode 100644
index 0000000000..52568139f7
--- /dev/null
+++ b/vendor/github.com/pion/dtls/v2/flight4handler.go
@@ -0,0 +1,411 @@
+// SPDX-FileCopyrightText: 2023 The Pion community
+// SPDX-License-Identifier: MIT
+
+package dtls
+
+import (
+ "context"
+ "crypto/rand"
+ "crypto/x509"
+
+ "github.com/pion/dtls/v2/internal/ciphersuite"
+ "github.com/pion/dtls/v2/pkg/crypto/clientcertificate"
+ "github.com/pion/dtls/v2/pkg/crypto/elliptic"
+ "github.com/pion/dtls/v2/pkg/crypto/prf"
+ "github.com/pion/dtls/v2/pkg/crypto/signaturehash"
+ "github.com/pion/dtls/v2/pkg/protocol"
+ "github.com/pion/dtls/v2/pkg/protocol/alert"
+ "github.com/pion/dtls/v2/pkg/protocol/extension"
+ "github.com/pion/dtls/v2/pkg/protocol/handshake"
+ "github.com/pion/dtls/v2/pkg/protocol/recordlayer"
+)
+
+func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { //nolint:gocognit
+ seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite,
+ handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, true},
+ handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false},
+ handshakeCachePullRule{handshake.TypeCertificateVerify, cfg.initialEpoch, true, true},
+ )
+ if !ok {
+ // No valid message received. Keep reading
+ return 0, nil, nil
+ }
+
+ // Validate type
+ var clientKeyExchange *handshake.MessageClientKeyExchange
+ if clientKeyExchange, ok = msgs[handshake.TypeClientKeyExchange].(*handshake.MessageClientKeyExchange); !ok {
+ return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
+ }
+
+ if h, hasCert := msgs[handshake.TypeCertificate].(*handshake.MessageCertificate); hasCert {
+ state.PeerCertificates = h.Certificate
+ // If the client offers its certificate, just disable session resumption.
+ // Otherwise, we would have to store the certificate identification and
+ // expiry time, and check whether the certificate has expired, been
+ // revoked, or changed.
+		//
+		// https://curl.se/docs/CVE-2016-5419.html
+		state.SessionID = nil
+	}
+
+	if h, hasCertVerify := msgs[handshake.TypeCertificateVerify].(*handshake.MessageCertificateVerify); hasCertVerify {
+		if state.PeerCertificates == nil {
+			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errCertificateVerifyNoCertificate
+		}
+
+		plainText := cache.pullAndMerge(
+			handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
+			handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
+			handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false},
+			handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
+			handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false},
+			handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
+			handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false},
+			handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false},
+		)
+
+		// Verify that the pair of hash algorithm and signature is listed.
+		var validSignatureScheme bool
+		for _, ss := range cfg.localSignatureSchemes {
+			if ss.Hash == h.HashAlgorithm && ss.Signature == h.SignatureAlgorithm {
+				validSignatureScheme = true
+				break
+			}
+		}
+		if !validSignatureScheme {
+			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errNoAvailableSignatureSchemes
+		}
+
+		if err := verifyCertificateVerify(plainText, h.HashAlgorithm, h.Signature, state.PeerCertificates); err != nil {
+			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
+		}
+		var chains [][]*x509.Certificate
+		var err error
+		var verified bool
+		if cfg.clientAuth >= VerifyClientCertIfGiven {
+			if chains, err = verifyClientCert(state.PeerCertificates, cfg.clientCAs); err != nil {
+				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
+			}
+			verified = true
+		}
+		if cfg.verifyPeerCertificate != nil {
+			if err := cfg.verifyPeerCertificate(state.PeerCertificates, chains); err != nil {
+				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
+			}
+		}
+		state.peerCertificatesVerified = verified
+	} else if state.PeerCertificates != nil {
+		// A certificate was received, but we haven't seen a CertificateVerify
+		// keep reading until we receive one
+		return 0, nil, nil
+	}
+
+	if !state.cipherSuite.IsInitialized() {
+		serverRandom := state.localRandom.MarshalFixed()
+		clientRandom := state.remoteRandom.MarshalFixed()
+
+		var err error
+		var preMasterSecret []byte
+		if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypePreSharedKey {
+			var psk []byte
+			if psk, err = cfg.localPSKCallback(clientKeyExchange.IdentityHint); err != nil {
+				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
+			}
+			state.IdentityHint = clientKeyExchange.IdentityHint
+			switch state.cipherSuite.KeyExchangeAlgorithm() {
+			case CipherSuiteKeyExchangeAlgorithmPsk:
+				preMasterSecret = prf.PSKPreMasterSecret(psk)
+			case (CipherSuiteKeyExchangeAlgorithmPsk | CipherSuiteKeyExchangeAlgorithmEcdhe):
+				if preMasterSecret, err = prf.EcdhePSKPreMasterSecret(psk, clientKeyExchange.PublicKey, state.localKeypair.PrivateKey, state.localKeypair.Curve); err != nil {
+					return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
+				}
+			default:
+				return 0, &alert.Alert{Level: alert.Fatal,
Description: alert.InternalError}, errInvalidCipherSuite + } + } else { + preMasterSecret, err = prf.PreMasterSecret(clientKeyExchange.PublicKey, state.localKeypair.PrivateKey, state.localKeypair.Curve) + if err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, err + } + } + + if state.extendedMasterSecret { + var sessionHash []byte + sessionHash, err = cache.sessionHash(state.cipherSuite.HashFunc(), cfg.initialEpoch) + if err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + + state.masterSecret, err = prf.ExtendedMasterSecret(preMasterSecret, sessionHash, state.cipherSuite.HashFunc()) + if err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + } else { + state.masterSecret, err = prf.MasterSecret(preMasterSecret, clientRandom[:], serverRandom[:], state.cipherSuite.HashFunc()) + if err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + } + + if err := state.cipherSuite.Init(state.masterSecret, clientRandom[:], serverRandom[:], false); err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret) + } + + if len(state.SessionID) > 0 { + s := Session{ + ID: state.SessionID, + Secret: state.masterSecret, + } + cfg.log.Tracef("[handshake] save new session: %x", s.ID) + if err := cfg.sessionStore.Set(state.SessionID, s); err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + } + + // Now, encrypted packets can be handled + if err := c.handleQueuedPackets(ctx); err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + + seq, msgs, ok = cache.fullPullMap(seq, state.cipherSuite, + handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false}, + ) + if !ok { + // No valid message received. 
Keep reading + return 0, nil, nil + } + state.handshakeRecvSequence = seq + + if _, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil + } + + if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeAnonymous { + if cfg.verifyConnection != nil { + if err := cfg.verifyConnection(state.clone()); err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err + } + } + return flight6, nil, nil + } + + switch cfg.clientAuth { + case RequireAnyClientCert: + if state.PeerCertificates == nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errClientCertificateRequired + } + case VerifyClientCertIfGiven: + if state.PeerCertificates != nil && !state.peerCertificatesVerified { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, errClientCertificateNotVerified + } + case RequireAndVerifyClientCert: + if state.PeerCertificates == nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errClientCertificateRequired + } + if !state.peerCertificatesVerified { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, errClientCertificateNotVerified + } + case NoClientCert, RequestClientCert: + // go to flight6 + } + if cfg.verifyConnection != nil { + if err := cfg.verifyConnection(state.clone()); err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err + } + } + + return flight6, nil, nil +} + +func flight4Generate(_ flightConn, state *State, _ *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) { //nolint:gocognit + extensions := []extension.Extension{&extension.RenegotiationInfo{ + RenegotiatedConnection: 0, + }} + if (cfg.extendedMasterSecret == RequestExtendedMasterSecret || + cfg.extendedMasterSecret == RequireExtendedMasterSecret) && state.extendedMasterSecret { + extensions = append(extensions, &extension.UseExtendedMasterSecret{ + Supported: true, + }) + } + if state.srtpProtectionProfile != 0 { + extensions = append(extensions, &extension.UseSRTP{ + ProtectionProfiles: []SRTPProtectionProfile{state.srtpProtectionProfile}, + }) + } + if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate { + extensions = append(extensions, &extension.SupportedPointFormats{ + PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed}, + }) + } + + selectedProto, err := extension.ALPNProtocolSelection(cfg.supportedProtocols, state.peerSupportedProtocols) + if err != nil { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.NoApplicationProtocol}, err + } + if selectedProto != "" { + extensions = append(extensions, &extension.ALPN{ + ProtocolNameList: []string{selectedProto}, + }) + state.NegotiatedProtocol = selectedProto + } + + // If we have a connection ID generator, we are willing to use connection + // IDs. We already know whether the client supports connection IDs from + // parsing the ClientHello, so avoid setting local connection ID if the + // client won't send it. 
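+	// For illustration (helper name from this library version, not defined in
+	// this file): callers typically wire this up with something like
+	// RandomCIDGenerator(8), i.e. any func() []byte that returns the
+	// connection ID to advertise.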
+ if cfg.connectionIDGenerator != nil && state.remoteConnectionID != nil { + state.localConnectionID = cfg.connectionIDGenerator() + extensions = append(extensions, &extension.ConnectionID{CID: state.localConnectionID}) + } + + var pkts []*packet + cipherSuiteID := uint16(state.cipherSuite.ID()) + + if cfg.sessionStore != nil { + state.SessionID = make([]byte, sessionLength) + if _, err := rand.Read(state.SessionID); err != nil { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + } + + pkts = append(pkts, &packet{ + record: &recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Version: protocol.Version1_2, + }, + Content: &handshake.Handshake{ + Message: &handshake.MessageServerHello{ + Version: protocol.Version1_2, + Random: state.localRandom, + SessionID: state.SessionID, + CipherSuiteID: &cipherSuiteID, + CompressionMethod: defaultCompressionMethods()[0], + Extensions: extensions, + }, + }, + }, + }) + + switch { + case state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate: + certificate, err := cfg.getCertificate(&ClientHelloInfo{ + ServerName: state.serverName, + CipherSuites: []ciphersuite.ID{state.cipherSuite.ID()}, + }) + if err != nil { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, err + } + + pkts = append(pkts, &packet{ + record: &recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Version: protocol.Version1_2, + }, + Content: &handshake.Handshake{ + Message: &handshake.MessageCertificate{ + Certificate: certificate.Certificate, + }, + }, + }, + }) + + serverRandom := state.localRandom.MarshalFixed() + clientRandom := state.remoteRandom.MarshalFixed() + + // Find compatible signature scheme + signatureHashAlgo, err := signaturehash.SelectSignatureScheme(cfg.localSignatureSchemes, certificate.PrivateKey) + if err != nil { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, err + } + + signature, err := generateKeySignature(clientRandom[:], serverRandom[:], state.localKeypair.PublicKey, state.namedCurve, certificate.PrivateKey, signatureHashAlgo.Hash) + if err != nil { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + state.localKeySignature = signature + + pkts = append(pkts, &packet{ + record: &recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Version: protocol.Version1_2, + }, + Content: &handshake.Handshake{ + Message: &handshake.MessageServerKeyExchange{ + EllipticCurveType: elliptic.CurveTypeNamedCurve, + NamedCurve: state.namedCurve, + PublicKey: state.localKeypair.PublicKey, + HashAlgorithm: signatureHashAlgo.Hash, + SignatureAlgorithm: signatureHashAlgo.Signature, + Signature: state.localKeySignature, + }, + }, + }, + }) + + if cfg.clientAuth > NoClientCert { + // An empty list of certificateAuthorities signals to + // the client that it may send any certificate in response + // to our request. When we know the CAs we trust, then + // we can send them down, so that the client can choose + // an appropriate certificate to give to us. + var certificateAuthorities [][]byte + if cfg.clientCAs != nil { + // nolint:staticcheck // ignoring tlsCert.RootCAs.Subjects is deprecated ERR because cert does not come from SystemCertPool and it's ok if certificate authorities is empty. 
+ certificateAuthorities = cfg.clientCAs.Subjects() + } + pkts = append(pkts, &packet{ + record: &recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Version: protocol.Version1_2, + }, + Content: &handshake.Handshake{ + Message: &handshake.MessageCertificateRequest{ + CertificateTypes: []clientcertificate.Type{clientcertificate.RSASign, clientcertificate.ECDSASign}, + SignatureHashAlgorithms: cfg.localSignatureSchemes, + CertificateAuthoritiesNames: certificateAuthorities, + }, + }, + }, + }) + } + case cfg.localPSKIdentityHint != nil || state.cipherSuite.KeyExchangeAlgorithm().Has(CipherSuiteKeyExchangeAlgorithmEcdhe): + // To help the client in selecting which identity to use, the server + // can provide a "PSK identity hint" in the ServerKeyExchange message. + // If no hint is provided and cipher suite doesn't use elliptic curve, + // the ServerKeyExchange message is omitted. + // + // https://tools.ietf.org/html/rfc4279#section-2 + srvExchange := &handshake.MessageServerKeyExchange{ + IdentityHint: cfg.localPSKIdentityHint, + } + if state.cipherSuite.KeyExchangeAlgorithm().Has(CipherSuiteKeyExchangeAlgorithmEcdhe) { + srvExchange.EllipticCurveType = elliptic.CurveTypeNamedCurve + srvExchange.NamedCurve = state.namedCurve + srvExchange.PublicKey = state.localKeypair.PublicKey + } + pkts = append(pkts, &packet{ + record: &recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Version: protocol.Version1_2, + }, + Content: &handshake.Handshake{ + Message: srvExchange, + }, + }, + }) + } + + pkts = append(pkts, &packet{ + record: &recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Version: protocol.Version1_2, + }, + Content: &handshake.Handshake{ + Message: &handshake.MessageServerHelloDone{}, + }, + }, + }) + + return pkts, nil, nil +} diff --git a/vendor/github.com/pion/dtls/v2/flight5handler.go b/vendor/github.com/pion/dtls/v2/flight5handler.go new file mode 100644 index 0000000000..e1cca6238a --- /dev/null +++ b/vendor/github.com/pion/dtls/v2/flight5handler.go @@ -0,0 +1,358 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "bytes" + "context" + "crypto" + "crypto/x509" + + "github.com/pion/dtls/v2/pkg/crypto/prf" + "github.com/pion/dtls/v2/pkg/crypto/signaturehash" + "github.com/pion/dtls/v2/pkg/protocol" + "github.com/pion/dtls/v2/pkg/protocol/alert" + "github.com/pion/dtls/v2/pkg/protocol/handshake" + "github.com/pion/dtls/v2/pkg/protocol/recordlayer" +) + +func flight5Parse(_ context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { + _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite, + handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, false, false}, + ) + if !ok { + // No valid message received. 
Keep reading + return 0, nil, nil + } + + var finished *handshake.MessageFinished + if finished, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil + } + plainText := cache.pullAndMerge( + handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, + handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false}, + handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false}, + handshakeCachePullRule{handshake.TypeCertificateVerify, cfg.initialEpoch, true, false}, + handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false}, + ) + + expectedVerifyData, err := prf.VerifyDataServer(state.masterSecret, plainText, state.cipherSuite.HashFunc()) + if err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + if !bytes.Equal(expectedVerifyData, finished.VerifyData) { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errVerifyDataMismatch + } + + if len(state.SessionID) > 0 { + s := Session{ + ID: state.SessionID, + Secret: state.masterSecret, + } + cfg.log.Tracef("[handshake] save new session: %x", s.ID) + if err := cfg.sessionStore.Set(c.sessionKey(), s); err != nil { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + } + + return flight5, nil, nil +} + +func flight5Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) { //nolint:gocognit + var privateKey crypto.PrivateKey + var pkts []*packet + if state.remoteRequestedCertificate { + _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence-2, state.cipherSuite, + handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false}) + if !ok { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errClientCertificateRequired + } + reqInfo := CertificateRequestInfo{} + if r, ok := msgs[handshake.TypeCertificateRequest].(*handshake.MessageCertificateRequest); ok { + reqInfo.AcceptableCAs = r.CertificateAuthoritiesNames + } else { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errClientCertificateRequired + } + certificate, err := cfg.getClientCertificate(&reqInfo) + if err != nil { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, err + } + if certificate == nil { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errNotAcceptableCertificateChain + } + if certificate.Certificate != nil { + privateKey = certificate.PrivateKey + } + pkts = append(pkts, + &packet{ + record: &recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Version: protocol.Version1_2, + }, + Content: &handshake.Handshake{ + Message: &handshake.MessageCertificate{ + Certificate: certificate.Certificate, + }, + }, + }, + }) + } + + clientKeyExchange := &handshake.MessageClientKeyExchange{} + if cfg.localPSKCallback == nil 
{ + clientKeyExchange.PublicKey = state.localKeypair.PublicKey + } else { + clientKeyExchange.IdentityHint = cfg.localPSKIdentityHint + } + if state != nil && state.localKeypair != nil && len(state.localKeypair.PublicKey) > 0 { + clientKeyExchange.PublicKey = state.localKeypair.PublicKey + } + + pkts = append(pkts, + &packet{ + record: &recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Version: protocol.Version1_2, + }, + Content: &handshake.Handshake{ + Message: clientKeyExchange, + }, + }, + }) + + serverKeyExchangeData := cache.pullAndMerge( + handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false}, + ) + + serverKeyExchange := &handshake.MessageServerKeyExchange{} + + // handshakeMessageServerKeyExchange is optional for PSK + if len(serverKeyExchangeData) == 0 { + alertPtr, err := handleServerKeyExchange(c, state, cfg, &handshake.MessageServerKeyExchange{}) + if err != nil { + return nil, alertPtr, err + } + } else { + rawHandshake := &handshake.Handshake{ + KeyExchangeAlgorithm: state.cipherSuite.KeyExchangeAlgorithm(), + } + err := rawHandshake.Unmarshal(serverKeyExchangeData) + if err != nil { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, err + } + + switch h := rawHandshake.Message.(type) { + case *handshake.MessageServerKeyExchange: + serverKeyExchange = h + default: + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, errInvalidContentType + } + } + + // Append not-yet-sent packets + merged := []byte{} + seqPred := uint16(state.handshakeSendSequence) + for _, p := range pkts { + h, ok := p.record.Content.(*handshake.Handshake) + if !ok { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errInvalidContentType + } + h.Header.MessageSequence = seqPred + seqPred++ + raw, err := h.Marshal() + if err != nil { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + merged = append(merged, raw...) + } + + if alertPtr, err := initializeCipherSuite(state, cache, cfg, serverKeyExchange, merged); err != nil { + return nil, alertPtr, err + } + + // If the client has sent a certificate with signing ability, a digitally-signed + // CertificateVerify message is sent to explicitly verify possession of the + // private key in the certificate. + if state.remoteRequestedCertificate && privateKey != nil { + plainText := append(cache.pullAndMerge( + handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, + handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false}, + handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false}, + ), merged...) 
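+		// plainText now covers the transcript up to and including our
+		// still-unsent ClientKeyExchange (appended via merged); per TLS 1.2
+		// the CertificateVerify signature below must be computed over exactly
+		// this running handshake transcript.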
+ + // Find compatible signature scheme + signatureHashAlgo, err := signaturehash.SelectSignatureScheme(cfg.localSignatureSchemes, privateKey) + if err != nil { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, err + } + + certVerify, err := generateCertificateVerify(plainText, privateKey, signatureHashAlgo.Hash) + if err != nil { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + state.localCertificatesVerify = certVerify + + p := &packet{ + record: &recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Version: protocol.Version1_2, + }, + Content: &handshake.Handshake{ + Message: &handshake.MessageCertificateVerify{ + HashAlgorithm: signatureHashAlgo.Hash, + SignatureAlgorithm: signatureHashAlgo.Signature, + Signature: state.localCertificatesVerify, + }, + }, + }, + } + pkts = append(pkts, p) + + h, ok := p.record.Content.(*handshake.Handshake) + if !ok { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errInvalidContentType + } + h.Header.MessageSequence = seqPred + // seqPred++ // this is the last use of seqPred + raw, err := h.Marshal() + if err != nil { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + merged = append(merged, raw...) + } + + pkts = append(pkts, + &packet{ + record: &recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Version: protocol.Version1_2, + }, + Content: &protocol.ChangeCipherSpec{}, + }, + }) + + if len(state.localVerifyData) == 0 { + plainText := cache.pullAndMerge( + handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, + handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false}, + handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false}, + handshakeCachePullRule{handshake.TypeCertificateVerify, cfg.initialEpoch, true, false}, + handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false}, + ) + + var err error + state.localVerifyData, err = prf.VerifyDataClient(state.masterSecret, append(plainText, merged...), state.cipherSuite.HashFunc()) + if err != nil { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + } + + pkts = append(pkts, + &packet{ + record: &recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Version: protocol.Version1_2, + Epoch: 1, + }, + Content: &handshake.Handshake{ + Message: &handshake.MessageFinished{ + VerifyData: state.localVerifyData, + }, + }, + }, + shouldWrapCID: len(state.remoteConnectionID) > 0, + shouldEncrypt: true, + resetLocalSequenceNumber: true, + }) + + return pkts, nil, nil +} + +func initializeCipherSuite(state *State, cache *handshakeCache, cfg *handshakeConfig, h *handshake.MessageServerKeyExchange, sendingPlainText []byte) (*alert.Alert, error) { //nolint:gocognit + if state.cipherSuite.IsInitialized() { + return nil, nil //nolint + } + + clientRandom := state.localRandom.MarshalFixed() + serverRandom := state.remoteRandom.MarshalFixed() + + var err error + + if 
state.extendedMasterSecret {
+		var sessionHash []byte
+		sessionHash, err = cache.sessionHash(state.cipherSuite.HashFunc(), cfg.initialEpoch, sendingPlainText)
+		if err != nil {
+			return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
+		}
+
+		state.masterSecret, err = prf.ExtendedMasterSecret(state.preMasterSecret, sessionHash, state.cipherSuite.HashFunc())
+		if err != nil {
+			return &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, err
+		}
+	} else {
+		state.masterSecret, err = prf.MasterSecret(state.preMasterSecret, clientRandom[:], serverRandom[:], state.cipherSuite.HashFunc())
+		if err != nil {
+			return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
+		}
+	}
+
+	if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate {
+		// Verify that the pair of hash algorithm and signature is listed.
+		var validSignatureScheme bool
+		for _, ss := range cfg.localSignatureSchemes {
+			if ss.Hash == h.HashAlgorithm && ss.Signature == h.SignatureAlgorithm {
+				validSignatureScheme = true
+				break
+			}
+		}
+		if !validSignatureScheme {
+			return &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errNoAvailableSignatureSchemes
+		}
+
+		expectedMsg := valueKeyMessage(clientRandom[:], serverRandom[:], h.PublicKey, h.NamedCurve)
+		if err = verifyKeySignature(expectedMsg, h.Signature, h.HashAlgorithm, state.PeerCertificates); err != nil {
+			return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
+		}
+		var chains [][]*x509.Certificate
+		if !cfg.insecureSkipVerify {
+			if chains, err = verifyServerCert(state.PeerCertificates, cfg.rootCAs, cfg.serverName); err != nil {
+				return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
+			}
+		}
+		if cfg.verifyPeerCertificate != nil {
+			if err = cfg.verifyPeerCertificate(state.PeerCertificates, chains); err != nil {
+				return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
+			}
+		}
+	}
+	if cfg.verifyConnection != nil {
+		if err = cfg.verifyConnection(state.clone()); err != nil {
+			return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
+		}
+	}
+
+	if err = state.cipherSuite.Init(state.masterSecret, clientRandom[:], serverRandom[:], true); err != nil {
+		return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
+	}
+
+	cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret)
+
+	return nil, nil //nolint
+}
diff --git a/vendor/github.com/pion/dtls/v2/flight6handler.go b/vendor/github.com/pion/dtls/v2/flight6handler.go
new file mode 100644
index 0000000000..c038904256
--- /dev/null
+++ b/vendor/github.com/pion/dtls/v2/flight6handler.go
@@ -0,0 +1,86 @@
+// SPDX-FileCopyrightText: 2023 The Pion community
+// SPDX-License-Identifier: MIT
+
+package dtls
+
+import (
+	"context"
+
+	"github.com/pion/dtls/v2/pkg/crypto/prf"
+	"github.com/pion/dtls/v2/pkg/protocol"
+	"github.com/pion/dtls/v2/pkg/protocol/alert"
+	"github.com/pion/dtls/v2/pkg/protocol/handshake"
+	"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
+)
+
+func flight6Parse(_ context.Context, _ flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) {
+	_, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence-1, state.cipherSuite,
+		handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false},
+	)
+	if !ok {
+		// No valid message received.
Keep reading + return 0, nil, nil + } + + if _, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok { + return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil + } + + // Other party may re-transmit the last flight. Keep state to be flight6. + return flight6, nil, nil +} + +func flight6Generate(_ flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) { + var pkts []*packet + + pkts = append(pkts, + &packet{ + record: &recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Version: protocol.Version1_2, + }, + Content: &protocol.ChangeCipherSpec{}, + }, + }) + + if len(state.localVerifyData) == 0 { + plainText := cache.pullAndMerge( + handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, + handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false}, + handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false}, + handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false}, + handshakeCachePullRule{handshake.TypeCertificateVerify, cfg.initialEpoch, true, false}, + handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false}, + ) + + var err error + state.localVerifyData, err = prf.VerifyDataServer(state.masterSecret, plainText, state.cipherSuite.HashFunc()) + if err != nil { + return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err + } + } + + pkts = append(pkts, + &packet{ + record: &recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Version: protocol.Version1_2, + Epoch: 1, + }, + Content: &handshake.Handshake{ + Message: &handshake.MessageFinished{ + VerifyData: state.localVerifyData, + }, + }, + }, + shouldWrapCID: len(state.remoteConnectionID) > 0, + shouldEncrypt: true, + resetLocalSequenceNumber: true, + }, + ) + return pkts, nil, nil +} diff --git a/vendor/github.com/pion/dtls/v2/fragment_buffer.go b/vendor/github.com/pion/dtls/v2/fragment_buffer.go new file mode 100644 index 0000000000..fb5af6c3db --- /dev/null +++ b/vendor/github.com/pion/dtls/v2/fragment_buffer.go @@ -0,0 +1,132 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "github.com/pion/dtls/v2/pkg/protocol" + "github.com/pion/dtls/v2/pkg/protocol/handshake" + "github.com/pion/dtls/v2/pkg/protocol/recordlayer" +) + +// 2 megabytes +const fragmentBufferMaxSize = 2000000 + +type fragment struct { + recordLayerHeader recordlayer.Header + handshakeHeader handshake.Header + data []byte +} + +type fragmentBuffer struct { + // map of MessageSequenceNumbers that hold slices of fragments + cache map[uint16][]*fragment + + currentMessageSequenceNumber uint16 +} + +func newFragmentBuffer() *fragmentBuffer { + return &fragmentBuffer{cache: map[uint16][]*fragment{}} +} + +// current total size of buffer +func (f *fragmentBuffer) size() int { + size := 0 + for i := range f.cache { + for j := range f.cache[i] { + size += len(f.cache[i][j].data) + } + } + return size +} + +// Attempts to push a DTLS packet to the fragmentBuffer +// when it returns true it means the fragmentBuffer 
has inserted the packet, and the caller should not process it further;
+// a returned error is fatal, and the DTLS connection should be stopped
+func (f *fragmentBuffer) push(buf []byte) (bool, error) {
+	if f.size()+len(buf) >= fragmentBufferMaxSize {
+		return false, errFragmentBufferOverflow
+	}
+
+	frag := new(fragment)
+	if err := frag.recordLayerHeader.Unmarshal(buf); err != nil {
+		return false, err
+	}
+
+	// fragment isn't a handshake, we don't need to handle it
+	if frag.recordLayerHeader.ContentType != protocol.ContentTypeHandshake {
+		return false, nil
+	}
+
+	for buf = buf[recordlayer.FixedHeaderSize:]; len(buf) != 0; frag = new(fragment) {
+		if err := frag.handshakeHeader.Unmarshal(buf); err != nil {
+			return false, err
+		}
+
+		if _, ok := f.cache[frag.handshakeHeader.MessageSequence]; !ok {
+			f.cache[frag.handshakeHeader.MessageSequence] = []*fragment{}
+		}
+
+		// end index should be the length of handshake header but if the handshake
+		// was fragmented, we should keep them all
+		end := int(handshake.HeaderLength + frag.handshakeHeader.Length)
+		if size := len(buf); end > size {
+			end = size
+		}
+
+		// Discard all headers; a single header is re-built when the message is popped
+		frag.data = append([]byte{}, buf[handshake.HeaderLength:end]...)
+		f.cache[frag.handshakeHeader.MessageSequence] = append(f.cache[frag.handshakeHeader.MessageSequence], frag)
+		buf = buf[end:]
+	}
+
+	return true, nil
+}
+
+func (f *fragmentBuffer) pop() (content []byte, epoch uint16) {
+	frags, ok := f.cache[f.currentMessageSequenceNumber]
+	if !ok {
+		return nil, 0
+	}
+
+	// Go doesn't support recursive lambdas
+	var appendMessage func(targetOffset uint32) bool
+
+	rawMessage := []byte{}
+	appendMessage = func(targetOffset uint32) bool {
+		for _, f := range frags {
+			if f.handshakeHeader.FragmentOffset == targetOffset {
+				fragmentEnd := (f.handshakeHeader.FragmentOffset + f.handshakeHeader.FragmentLength)
+				if fragmentEnd != f.handshakeHeader.Length && f.handshakeHeader.FragmentLength != 0 {
+					if !appendMessage(fragmentEnd) {
+						return false
+					}
+				}
+
+				rawMessage = append(f.data, rawMessage...)
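+				// Reassembly runs back to front: the recursive call above first
+				// collects every fragment that starts where this one ends, then
+				// this fragment's bytes are prepended to the result.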
+ return true + } + } + return false + } + + // Recursively collect up + if !appendMessage(0) { + return nil, 0 + } + + firstHeader := frags[0].handshakeHeader + firstHeader.FragmentOffset = 0 + firstHeader.FragmentLength = firstHeader.Length + + rawHeader, err := firstHeader.Marshal() + if err != nil { + return nil, 0 + } + + messageEpoch := frags[0].recordLayerHeader.Epoch + + delete(f.cache, f.currentMessageSequenceNumber) + f.currentMessageSequenceNumber++ + return append(rawHeader, rawMessage...), messageEpoch +} diff --git a/vendor/github.com/pion/dtls/v2/handshaker.go b/vendor/github.com/pion/dtls/v2/handshaker.go new file mode 100644 index 0000000000..46fbd38bda --- /dev/null +++ b/vendor/github.com/pion/dtls/v2/handshaker.go @@ -0,0 +1,347 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "io" + "sync" + "time" + + "github.com/pion/dtls/v2/pkg/crypto/elliptic" + "github.com/pion/dtls/v2/pkg/crypto/signaturehash" + "github.com/pion/dtls/v2/pkg/protocol/alert" + "github.com/pion/dtls/v2/pkg/protocol/handshake" + "github.com/pion/logging" +) + +// [RFC6347 Section-4.2.4] +// +-----------+ +// +---> | PREPARING | <--------------------+ +// | +-----------+ | +// | | | +// | | Buffer next flight | +// | | | +// | \|/ | +// | +-----------+ | +// | | SENDING |<------------------+ | Send +// | +-----------+ | | HelloRequest +// Receive | | | | +// next | | Send flight | | or +// flight | +--------+ | | +// | | | Set retransmit timer | | Receive +// | | \|/ | | HelloRequest +// | | +-----------+ | | Send +// +--)--| WAITING |-------------------+ | ClientHello +// | | +-----------+ Timer expires | | +// | | | | | +// | | +------------------------+ | +// Receive | | Send Read retransmit | +// last | | last | +// flight | | flight | +// | | | +// \|/\|/ | +// +-----------+ | +// | FINISHED | -------------------------------+ +// +-----------+ +// | /|\ +// | | +// +---+ +// Read retransmit +// Retransmit last flight + +type handshakeState uint8 + +const ( + handshakeErrored handshakeState = iota + handshakePreparing + handshakeSending + handshakeWaiting + handshakeFinished +) + +func (s handshakeState) String() string { + switch s { + case handshakeErrored: + return "Errored" + case handshakePreparing: + return "Preparing" + case handshakeSending: + return "Sending" + case handshakeWaiting: + return "Waiting" + case handshakeFinished: + return "Finished" + default: + return "Unknown" + } +} + +type handshakeFSM struct { + currentFlight flightVal + flights []*packet + retransmit bool + state *State + cache *handshakeCache + cfg *handshakeConfig + closed chan struct{} +} + +type handshakeConfig struct { + localPSKCallback PSKCallback + localPSKIdentityHint []byte + localCipherSuites []CipherSuite // Available CipherSuites + localSignatureSchemes []signaturehash.Algorithm // Available signature schemes + extendedMasterSecret ExtendedMasterSecretType // Policy for the Extended Master Support extension + localSRTPProtectionProfiles []SRTPProtectionProfile // Available SRTPProtectionProfiles, if empty no SRTP support + serverName string + supportedProtocols []string + clientAuth ClientAuthType // If we are a client should we request a client certificate + localCertificates []tls.Certificate + nameToCertificate map[string]*tls.Certificate + insecureSkipVerify bool + verifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error + verifyConnection 
func(*State) error + sessionStore SessionStore + rootCAs *x509.CertPool + clientCAs *x509.CertPool + retransmitInterval time.Duration + customCipherSuites func() []CipherSuite + ellipticCurves []elliptic.Curve + insecureSkipHelloVerify bool + connectionIDGenerator func() []byte + + onFlightState func(flightVal, handshakeState) + log logging.LeveledLogger + keyLogWriter io.Writer + + localGetCertificate func(*ClientHelloInfo) (*tls.Certificate, error) + localGetClientCertificate func(*CertificateRequestInfo) (*tls.Certificate, error) + + initialEpoch uint16 + + mu sync.Mutex +} + +type flightConn interface { + notify(ctx context.Context, level alert.Level, desc alert.Description) error + writePackets(context.Context, []*packet) error + recvHandshake() <-chan chan struct{} + setLocalEpoch(epoch uint16) + handleQueuedPackets(context.Context) error + sessionKey() []byte +} + +func (c *handshakeConfig) writeKeyLog(label string, clientRandom, secret []byte) { + if c.keyLogWriter == nil { + return + } + c.mu.Lock() + defer c.mu.Unlock() + _, err := c.keyLogWriter.Write([]byte(fmt.Sprintf("%s %x %x\n", label, clientRandom, secret))) + if err != nil { + c.log.Debugf("failed to write key log file: %s", err) + } +} + +func srvCliStr(isClient bool) string { + if isClient { + return "client" + } + return "server" +} + +func newHandshakeFSM( + s *State, cache *handshakeCache, cfg *handshakeConfig, + initialFlight flightVal, +) *handshakeFSM { + return &handshakeFSM{ + currentFlight: initialFlight, + state: s, + cache: cache, + cfg: cfg, + closed: make(chan struct{}), + } +} + +func (s *handshakeFSM) Run(ctx context.Context, c flightConn, initialState handshakeState) error { + state := initialState + defer func() { + close(s.closed) + }() + for { + s.cfg.log.Tracef("[handshake:%s] %s: %s", srvCliStr(s.state.isClient), s.currentFlight.String(), state.String()) + if s.cfg.onFlightState != nil { + s.cfg.onFlightState(s.currentFlight, state) + } + var err error + switch state { + case handshakePreparing: + state, err = s.prepare(ctx, c) + case handshakeSending: + state, err = s.send(ctx, c) + case handshakeWaiting: + state, err = s.wait(ctx, c) + case handshakeFinished: + state, err = s.finish(ctx, c) + default: + return errInvalidFSMTransition + } + if err != nil { + return err + } + } +} + +func (s *handshakeFSM) Done() <-chan struct{} { + return s.closed +} + +func (s *handshakeFSM) prepare(ctx context.Context, c flightConn) (handshakeState, error) { + s.flights = nil + // Prepare flights + var ( + a *alert.Alert + err error + pkts []*packet + ) + gen, retransmit, errFlight := s.currentFlight.getFlightGenerator() + if errFlight != nil { + err = errFlight + a = &alert.Alert{Level: alert.Fatal, Description: alert.InternalError} + } else { + pkts, a, err = gen(c, s.state, s.cache, s.cfg) + s.retransmit = retransmit + } + if a != nil { + if alertErr := c.notify(ctx, a.Level, a.Description); alertErr != nil { + if err != nil { + err = alertErr + } + } + } + if err != nil { + return handshakeErrored, err + } + + s.flights = pkts + epoch := s.cfg.initialEpoch + nextEpoch := epoch + for _, p := range s.flights { + p.record.Header.Epoch += epoch + if p.record.Header.Epoch > nextEpoch { + nextEpoch = p.record.Header.Epoch + } + if h, ok := p.record.Content.(*handshake.Handshake); ok { + h.Header.MessageSequence = uint16(s.state.handshakeSendSequence) + s.state.handshakeSendSequence++ + } + } + if epoch != nextEpoch { + s.cfg.log.Tracef("[handshake:%s] -> changeCipherSpec (epoch: %d)", srvCliStr(s.state.isClient), 
nextEpoch) + c.setLocalEpoch(nextEpoch) + } + return handshakeSending, nil +} + +func (s *handshakeFSM) send(ctx context.Context, c flightConn) (handshakeState, error) { + // Send flights + if err := c.writePackets(ctx, s.flights); err != nil { + return handshakeErrored, err + } + + if s.currentFlight.isLastSendFlight() { + return handshakeFinished, nil + } + return handshakeWaiting, nil +} + +func (s *handshakeFSM) wait(ctx context.Context, c flightConn) (handshakeState, error) { //nolint:gocognit + parse, errFlight := s.currentFlight.getFlightParser() + if errFlight != nil { + if alertErr := c.notify(ctx, alert.Fatal, alert.InternalError); alertErr != nil { + return handshakeErrored, alertErr + } + return handshakeErrored, errFlight + } + + retransmitTimer := time.NewTimer(s.cfg.retransmitInterval) + for { + select { + case done := <-c.recvHandshake(): + nextFlight, alert, err := parse(ctx, c, s.state, s.cache, s.cfg) + close(done) + if alert != nil { + if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil { + if err != nil { + err = alertErr + } + } + } + if err != nil { + return handshakeErrored, err + } + if nextFlight == 0 { + break + } + s.cfg.log.Tracef("[handshake:%s] %s -> %s", srvCliStr(s.state.isClient), s.currentFlight.String(), nextFlight.String()) + if nextFlight.isLastRecvFlight() && s.currentFlight == nextFlight { + return handshakeFinished, nil + } + s.currentFlight = nextFlight + return handshakePreparing, nil + + case <-retransmitTimer.C: + if !s.retransmit { + return handshakeWaiting, nil + } + return handshakeSending, nil + case <-ctx.Done(): + return handshakeErrored, ctx.Err() + } + } +} + +func (s *handshakeFSM) finish(ctx context.Context, c flightConn) (handshakeState, error) { + parse, errFlight := s.currentFlight.getFlightParser() + if errFlight != nil { + if alertErr := c.notify(ctx, alert.Fatal, alert.InternalError); alertErr != nil { + return handshakeErrored, alertErr + } + return handshakeErrored, errFlight + } + + retransmitTimer := time.NewTimer(s.cfg.retransmitInterval) + select { + case done := <-c.recvHandshake(): + nextFlight, alert, err := parse(ctx, c, s.state, s.cache, s.cfg) + close(done) + if alert != nil { + if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil { + if err != nil { + err = alertErr + } + } + } + if err != nil { + return handshakeErrored, err + } + if nextFlight == 0 { + break + } + if nextFlight.isLastRecvFlight() && s.currentFlight == nextFlight { + return handshakeFinished, nil + } + <-retransmitTimer.C + // Retransmit last flight + return handshakeSending, nil + + case <-ctx.Done(): + return handshakeErrored, ctx.Err() + } + return handshakeFinished, nil +} diff --git a/vendor/github.com/pion/dtls/v2/internal/ciphersuite/aes_ccm.go b/vendor/github.com/pion/dtls/v2/internal/ciphersuite/aes_ccm.go new file mode 100644 index 0000000000..ee3cca9031 --- /dev/null +++ b/vendor/github.com/pion/dtls/v2/internal/ciphersuite/aes_ccm.go @@ -0,0 +1,113 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package ciphersuite + +import ( + "crypto/sha256" + "fmt" + "hash" + "sync/atomic" + + "github.com/pion/dtls/v2/pkg/crypto/ciphersuite" + "github.com/pion/dtls/v2/pkg/crypto/clientcertificate" + "github.com/pion/dtls/v2/pkg/crypto/prf" + "github.com/pion/dtls/v2/pkg/protocol/recordlayer" +) + +// AesCcm is a base class used by multiple AES-CCM Ciphers +type AesCcm struct { + ccm atomic.Value // *cryptoCCM + clientCertificateType clientcertificate.Type + 
id ID + psk bool + keyExchangeAlgorithm KeyExchangeAlgorithm + cryptoCCMTagLen ciphersuite.CCMTagLen + ecc bool +} + +// CertificateType returns what type of certificate this CipherSuite exchanges +func (c *AesCcm) CertificateType() clientcertificate.Type { + return c.clientCertificateType +} + +// ID returns the ID of the CipherSuite +func (c *AesCcm) ID() ID { + return c.id +} + +func (c *AesCcm) String() string { + return c.id.String() +} + +// ECC uses Elliptic Curve Cryptography +func (c *AesCcm) ECC() bool { + return c.ecc +} + +// KeyExchangeAlgorithm controls what key exchange algorithm is using during the handshake +func (c *AesCcm) KeyExchangeAlgorithm() KeyExchangeAlgorithm { + return c.keyExchangeAlgorithm +} + +// HashFunc returns the hashing func for this CipherSuite +func (c *AesCcm) HashFunc() func() hash.Hash { + return sha256.New +} + +// AuthenticationType controls what authentication method is using during the handshake +func (c *AesCcm) AuthenticationType() AuthenticationType { + if c.psk { + return AuthenticationTypePreSharedKey + } + return AuthenticationTypeCertificate +} + +// IsInitialized returns if the CipherSuite has keying material and can +// encrypt/decrypt packets +func (c *AesCcm) IsInitialized() bool { + return c.ccm.Load() != nil +} + +// Init initializes the internal Cipher with keying material +func (c *AesCcm) Init(masterSecret, clientRandom, serverRandom []byte, isClient bool, prfKeyLen int) error { + const ( + prfMacLen = 0 + prfIvLen = 4 + ) + + keys, err := prf.GenerateEncryptionKeys(masterSecret, clientRandom, serverRandom, prfMacLen, prfKeyLen, prfIvLen, c.HashFunc()) + if err != nil { + return err + } + + var ccm *ciphersuite.CCM + if isClient { + ccm, err = ciphersuite.NewCCM(c.cryptoCCMTagLen, keys.ClientWriteKey, keys.ClientWriteIV, keys.ServerWriteKey, keys.ServerWriteIV) + } else { + ccm, err = ciphersuite.NewCCM(c.cryptoCCMTagLen, keys.ServerWriteKey, keys.ServerWriteIV, keys.ClientWriteKey, keys.ClientWriteIV) + } + c.ccm.Store(ccm) + + return err +} + +// Encrypt encrypts a single TLS RecordLayer +func (c *AesCcm) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) { + cipherSuite, ok := c.ccm.Load().(*ciphersuite.CCM) + if !ok { + return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit) + } + + return cipherSuite.Encrypt(pkt, raw) +} + +// Decrypt decrypts a single TLS RecordLayer +func (c *AesCcm) Decrypt(h recordlayer.Header, raw []byte) ([]byte, error) { + cipherSuite, ok := c.ccm.Load().(*ciphersuite.CCM) + if !ok { + return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit) + } + + return cipherSuite.Decrypt(h, raw) +} diff --git a/vendor/github.com/pion/dtls/v2/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_128_gcm_sha256.go b/vendor/github.com/pion/dtls/v2/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_128_gcm_sha256.go new file mode 100644 index 0000000000..362370b987 --- /dev/null +++ b/vendor/github.com/pion/dtls/v2/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_128_gcm_sha256.go @@ -0,0 +1,108 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package ciphersuite + +import ( + "crypto/sha256" + "fmt" + "hash" + "sync/atomic" + + "github.com/pion/dtls/v2/pkg/crypto/ciphersuite" + "github.com/pion/dtls/v2/pkg/crypto/clientcertificate" + "github.com/pion/dtls/v2/pkg/crypto/prf" + "github.com/pion/dtls/v2/pkg/protocol/recordlayer" +) + +// TLSEcdheEcdsaWithAes128GcmSha256 represents a TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 CipherSuite +type 
TLSEcdheEcdsaWithAes128GcmSha256 struct {
+	gcm atomic.Value // *cryptoGCM
+}
+
+// CertificateType returns what type of certificate this CipherSuite exchanges
+func (c *TLSEcdheEcdsaWithAes128GcmSha256) CertificateType() clientcertificate.Type {
+	return clientcertificate.ECDSASign
+}
+
+// KeyExchangeAlgorithm controls what key exchange algorithm is used during the handshake
+func (c *TLSEcdheEcdsaWithAes128GcmSha256) KeyExchangeAlgorithm() KeyExchangeAlgorithm {
+	return KeyExchangeAlgorithmEcdhe
+}
+
+// ECC uses Elliptic Curve Cryptography
+func (c *TLSEcdheEcdsaWithAes128GcmSha256) ECC() bool {
+	return true
+}
+
+// ID returns the ID of the CipherSuite
+func (c *TLSEcdheEcdsaWithAes128GcmSha256) ID() ID {
+	return TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256
+}
+
+func (c *TLSEcdheEcdsaWithAes128GcmSha256) String() string {
+	return "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"
+}
+
+// HashFunc returns the hashing func for this CipherSuite
+func (c *TLSEcdheEcdsaWithAes128GcmSha256) HashFunc() func() hash.Hash {
+	return sha256.New
+}
+
+// AuthenticationType controls what authentication method is used during the handshake
+func (c *TLSEcdheEcdsaWithAes128GcmSha256) AuthenticationType() AuthenticationType {
+	return AuthenticationTypeCertificate
+}
+
+// IsInitialized returns if the CipherSuite has keying material and can
+// encrypt/decrypt packets
+func (c *TLSEcdheEcdsaWithAes128GcmSha256) IsInitialized() bool {
+	return c.gcm.Load() != nil
+}
+
+func (c *TLSEcdheEcdsaWithAes128GcmSha256) init(masterSecret, clientRandom, serverRandom []byte, isClient bool, prfMacLen, prfKeyLen, prfIvLen int, hashFunc func() hash.Hash) error {
+	keys, err := prf.GenerateEncryptionKeys(masterSecret, clientRandom, serverRandom, prfMacLen, prfKeyLen, prfIvLen, hashFunc)
+	if err != nil {
+		return err
+	}
+
+	var gcm *ciphersuite.GCM
+	if isClient {
+		gcm, err = ciphersuite.NewGCM(keys.ClientWriteKey, keys.ClientWriteIV, keys.ServerWriteKey, keys.ServerWriteIV)
+	} else {
+		gcm, err = ciphersuite.NewGCM(keys.ServerWriteKey, keys.ServerWriteIV, keys.ClientWriteKey, keys.ClientWriteIV)
+	}
+	c.gcm.Store(gcm)
+	return err
+}
+
+// Init initializes the internal Cipher with keying material
+func (c *TLSEcdheEcdsaWithAes128GcmSha256) Init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error {
+	const (
+		prfMacLen = 0
+		prfKeyLen = 16
+		prfIvLen  = 4
+	)
+
+	return c.init(masterSecret, clientRandom, serverRandom, isClient, prfMacLen, prfKeyLen, prfIvLen, c.HashFunc())
+}
+
+// Encrypt encrypts a single TLS RecordLayer
+func (c *TLSEcdheEcdsaWithAes128GcmSha256) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) {
+	cipherSuite, ok := c.gcm.Load().(*ciphersuite.GCM)
+	if !ok {
+		return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit)
+	}
+
+	return cipherSuite.Encrypt(pkt, raw)
+}
+
+// Decrypt decrypts a single TLS RecordLayer
+func (c *TLSEcdheEcdsaWithAes128GcmSha256) Decrypt(h recordlayer.Header, raw []byte) ([]byte, error) {
+	cipherSuite, ok := c.gcm.Load().(*ciphersuite.GCM)
+	if !ok {
+		return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit)
+	}
+
+	return cipherSuite.Decrypt(h, raw)
+}
diff --git a/vendor/github.com/pion/dtls/v2/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_256_cbc_sha.go b/vendor/github.com/pion/dtls/v2/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_256_cbc_sha.go
new file mode 100644
index 0000000000..07ad66fd1a
--- /dev/null
+++ b/vendor/github.com/pion/dtls/v2/internal/ciphersuite/tls_ecdhe_ecdsa_with_aes_256_cbc_sha.go
@@ -0,0 +1,114 @@
+// SPDX-FileCopyrightText: 2023 The Pion community
+// SPDX-License-Identifier: MIT
+
+package ciphersuite
+
+import (
+	"crypto/sha1" //nolint: gosec,gci
+	"crypto/sha256"
+	"fmt"
+	"hash"
+	"sync/atomic"
+
+	"github.com/pion/dtls/v2/pkg/crypto/ciphersuite"
+	"github.com/pion/dtls/v2/pkg/crypto/clientcertificate"
+	"github.com/pion/dtls/v2/pkg/crypto/prf"
+	"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
+)
+
+// TLSEcdheEcdsaWithAes256CbcSha represents a TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA CipherSuite
+type TLSEcdheEcdsaWithAes256CbcSha struct {
+	cbc atomic.Value // *cryptoCBC
+}
+
+// CertificateType returns what type of certificate this CipherSuite exchanges
+func (c *TLSEcdheEcdsaWithAes256CbcSha) CertificateType() clientcertificate.Type {
+	return clientcertificate.ECDSASign
+}
+
+// KeyExchangeAlgorithm controls what key exchange algorithm is used during the handshake
+func (c *TLSEcdheEcdsaWithAes256CbcSha) KeyExchangeAlgorithm() KeyExchangeAlgorithm {
+	return KeyExchangeAlgorithmEcdhe
+}
+
+// ECC uses Elliptic Curve Cryptography
+func (c *TLSEcdheEcdsaWithAes256CbcSha) ECC() bool {
+	return true
+}
+
+// ID returns the ID of the CipherSuite
+func (c *TLSEcdheEcdsaWithAes256CbcSha) ID() ID {
+	return TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA
+}
+
+func (c *TLSEcdheEcdsaWithAes256CbcSha) String() string {
+	return "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA"
+}
+
+// HashFunc returns the hashing func for this CipherSuite
+func (c *TLSEcdheEcdsaWithAes256CbcSha) HashFunc() func() hash.Hash {
+	return sha256.New
+}
+
+// AuthenticationType controls what authentication method is used during the handshake
+func (c *TLSEcdheEcdsaWithAes256CbcSha) AuthenticationType() AuthenticationType {
+	return AuthenticationTypeCertificate
+}
+
+// IsInitialized returns if the CipherSuite has keying material and can
+// encrypt/decrypt packets
+func (c *TLSEcdheEcdsaWithAes256CbcSha) IsInitialized() bool {
+	return c.cbc.Load() != nil
+}
+
+// Init initializes the internal Cipher with keying material
+func (c *TLSEcdheEcdsaWithAes256CbcSha) Init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error {
+	const (
+		prfMacLen = 20
+		prfKeyLen = 32
+		prfIvLen  = 16
+	)
+
+	keys, err := prf.GenerateEncryptionKeys(masterSecret, clientRandom, serverRandom, prfMacLen, prfKeyLen, prfIvLen, c.HashFunc())
+	if err != nil {
+		return err
+	}
+
+	var cbc *ciphersuite.CBC
+	if isClient {
+		cbc, err = ciphersuite.NewCBC(
+			keys.ClientWriteKey, keys.ClientWriteIV, keys.ClientMACKey,
+			keys.ServerWriteKey, keys.ServerWriteIV, keys.ServerMACKey,
+			sha1.New,
+		)
+	} else {
+		cbc, err = ciphersuite.NewCBC(
+			keys.ServerWriteKey, keys.ServerWriteIV, keys.ServerMACKey,
+			keys.ClientWriteKey, keys.ClientWriteIV, keys.ClientMACKey,
+			sha1.New,
+		)
+	}
+	c.cbc.Store(cbc)
+
+	return err
+}
+
+// Encrypt encrypts a single TLS RecordLayer
+func (c *TLSEcdheEcdsaWithAes256CbcSha) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) {
+	cipherSuite, ok := c.cbc.Load().(*ciphersuite.CBC)
+	if !ok {
+		return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit)
+	}
+
+	return cipherSuite.Encrypt(pkt, raw)
+}
+
+// Decrypt decrypts a single TLS RecordLayer
+func (c *TLSEcdheEcdsaWithAes256CbcSha) Decrypt(h recordlayer.Header, raw []byte) ([]byte, error) {
+	cipherSuite, ok := c.cbc.Load().(*ciphersuite.CBC)
+	if !ok {
+		return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit)
+	}
+
+	return cipherSuite.Decrypt(h, raw)
+}
diff --git
a/vendor/github.com/pion/dtls/v2/internal/ciphersuite/tls_ecdhe_psk_with_aes_128_cbc_sha256.go b/vendor/github.com/pion/dtls/v2/internal/ciphersuite/tls_ecdhe_psk_with_aes_128_cbc_sha256.go new file mode 100644 index 0000000000..10cc58c0ba --- /dev/null +++ b/vendor/github.com/pion/dtls/v2/internal/ciphersuite/tls_ecdhe_psk_with_aes_128_cbc_sha256.go @@ -0,0 +1,118 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package ciphersuite + +import ( + "crypto/sha256" + "fmt" + "hash" + "sync/atomic" + + "github.com/pion/dtls/v2/pkg/crypto/ciphersuite" + "github.com/pion/dtls/v2/pkg/crypto/clientcertificate" + "github.com/pion/dtls/v2/pkg/crypto/prf" + "github.com/pion/dtls/v2/pkg/protocol/recordlayer" +) + +// TLSEcdhePskWithAes128CbcSha256 implements the TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 CipherSuite +type TLSEcdhePskWithAes128CbcSha256 struct { + cbc atomic.Value // *cryptoCBC +} + +// NewTLSEcdhePskWithAes128CbcSha256 creates TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 cipher. +func NewTLSEcdhePskWithAes128CbcSha256() *TLSEcdhePskWithAes128CbcSha256 { + return &TLSEcdhePskWithAes128CbcSha256{} +} + +// CertificateType returns what type of certificate this CipherSuite exchanges +func (c *TLSEcdhePskWithAes128CbcSha256) CertificateType() clientcertificate.Type { + return clientcertificate.Type(0) +} + +// KeyExchangeAlgorithm controls what key exchange algorithm is using during the handshake +func (c *TLSEcdhePskWithAes128CbcSha256) KeyExchangeAlgorithm() KeyExchangeAlgorithm { + return (KeyExchangeAlgorithmPsk | KeyExchangeAlgorithmEcdhe) +} + +// ECC uses Elliptic Curve Cryptography +func (c *TLSEcdhePskWithAes128CbcSha256) ECC() bool { + return true +} + +// ID returns the ID of the CipherSuite +func (c *TLSEcdhePskWithAes128CbcSha256) ID() ID { + return TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 +} + +func (c *TLSEcdhePskWithAes128CbcSha256) String() string { + return "TLS-ECDHE-PSK-WITH-AES-128-CBC-SHA256" +} + +// HashFunc returns the hashing func for this CipherSuite +func (c *TLSEcdhePskWithAes128CbcSha256) HashFunc() func() hash.Hash { + return sha256.New +} + +// AuthenticationType controls what authentication method is using during the handshake +func (c *TLSEcdhePskWithAes128CbcSha256) AuthenticationType() AuthenticationType { + return AuthenticationTypePreSharedKey +} + +// IsInitialized returns if the CipherSuite has keying material and can +// encrypt/decrypt packets +func (c *TLSEcdhePskWithAes128CbcSha256) IsInitialized() bool { + return c.cbc.Load() != nil +} + +// Init initializes the internal Cipher with keying material +func (c *TLSEcdhePskWithAes128CbcSha256) Init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error { + const ( + prfMacLen = 32 + prfKeyLen = 16 + prfIvLen = 16 + ) + + keys, err := prf.GenerateEncryptionKeys(masterSecret, clientRandom, serverRandom, prfMacLen, prfKeyLen, prfIvLen, c.HashFunc()) + if err != nil { + return err + } + + var cbc *ciphersuite.CBC + if isClient { + cbc, err = ciphersuite.NewCBC( + keys.ClientWriteKey, keys.ClientWriteIV, keys.ClientMACKey, + keys.ServerWriteKey, keys.ServerWriteIV, keys.ServerMACKey, + c.HashFunc(), + ) + } else { + cbc, err = ciphersuite.NewCBC( + keys.ServerWriteKey, keys.ServerWriteIV, keys.ServerMACKey, + keys.ClientWriteKey, keys.ClientWriteIV, keys.ClientMACKey, + c.HashFunc(), + ) + } + c.cbc.Store(cbc) + + return err +} + +// Encrypt encrypts a single TLS RecordLayer +func (c *TLSEcdhePskWithAes128CbcSha256) Encrypt(pkt *recordlayer.RecordLayer, raw 
[]byte) ([]byte, error) { + cipherSuite, ok := c.cbc.Load().(*ciphersuite.CBC) + if !ok { // !c.isInitialized() + return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit) + } + + return cipherSuite.Encrypt(pkt, raw) +} + +// Decrypt decrypts a single TLS RecordLayer +func (c *TLSEcdhePskWithAes128CbcSha256) Decrypt(h recordlayer.Header, raw []byte) ([]byte, error) { + cipherSuite, ok := c.cbc.Load().(*ciphersuite.CBC) + if !ok { // !c.isInitialized() + return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit) + } + + return cipherSuite.Decrypt(h, raw) +} diff --git a/vendor/github.com/pion/dtls/v2/internal/ciphersuite/tls_psk_with_aes_128_cbc_sha256.go b/vendor/github.com/pion/dtls/v2/internal/ciphersuite/tls_psk_with_aes_128_cbc_sha256.go new file mode 100644 index 0000000000..dea0dfc75e --- /dev/null +++ b/vendor/github.com/pion/dtls/v2/internal/ciphersuite/tls_psk_with_aes_128_cbc_sha256.go @@ -0,0 +1,113 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package ciphersuite + +import ( + "crypto/sha256" + "fmt" + "hash" + "sync/atomic" + + "github.com/pion/dtls/v2/pkg/crypto/ciphersuite" + "github.com/pion/dtls/v2/pkg/crypto/clientcertificate" + "github.com/pion/dtls/v2/pkg/crypto/prf" + "github.com/pion/dtls/v2/pkg/protocol/recordlayer" +) + +// TLSPskWithAes128CbcSha256 implements the TLS_PSK_WITH_AES_128_CBC_SHA256 CipherSuite +type TLSPskWithAes128CbcSha256 struct { + cbc atomic.Value // *cryptoCBC +} + +// CertificateType returns what type of certificate this CipherSuite exchanges +func (c *TLSPskWithAes128CbcSha256) CertificateType() clientcertificate.Type { + return clientcertificate.Type(0) +} + +// KeyExchangeAlgorithm controls what key exchange algorithm is using during the handshake +func (c *TLSPskWithAes128CbcSha256) KeyExchangeAlgorithm() KeyExchangeAlgorithm { + return KeyExchangeAlgorithmPsk +} + +// ECC uses Elliptic Curve Cryptography +func (c *TLSPskWithAes128CbcSha256) ECC() bool { + return false +} + +// ID returns the ID of the CipherSuite +func (c *TLSPskWithAes128CbcSha256) ID() ID { + return TLS_PSK_WITH_AES_128_CBC_SHA256 +} + +func (c *TLSPskWithAes128CbcSha256) String() string { + return "TLS_PSK_WITH_AES_128_CBC_SHA256" +} + +// HashFunc returns the hashing func for this CipherSuite +func (c *TLSPskWithAes128CbcSha256) HashFunc() func() hash.Hash { + return sha256.New +} + +// AuthenticationType controls what authentication method is using during the handshake +func (c *TLSPskWithAes128CbcSha256) AuthenticationType() AuthenticationType { + return AuthenticationTypePreSharedKey +} + +// IsInitialized returns if the CipherSuite has keying material and can +// encrypt/decrypt packets +func (c *TLSPskWithAes128CbcSha256) IsInitialized() bool { + return c.cbc.Load() != nil +} + +// Init initializes the internal Cipher with keying material +func (c *TLSPskWithAes128CbcSha256) Init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error { + const ( + prfMacLen = 32 + prfKeyLen = 16 + prfIvLen = 16 + ) + + keys, err := prf.GenerateEncryptionKeys(masterSecret, clientRandom, serverRandom, prfMacLen, prfKeyLen, prfIvLen, c.HashFunc()) + if err != nil { + return err + } + + var cbc *ciphersuite.CBC + if isClient { + cbc, err = ciphersuite.NewCBC( + keys.ClientWriteKey, keys.ClientWriteIV, keys.ClientMACKey, + keys.ServerWriteKey, keys.ServerWriteIV, keys.ServerMACKey, + c.HashFunc(), + ) + } else { + cbc, err = ciphersuite.NewCBC( + keys.ServerWriteKey, keys.ServerWriteIV, 
keys.ServerMACKey, + keys.ClientWriteKey, keys.ClientWriteIV, keys.ClientMACKey, + c.HashFunc(), + ) + } + c.cbc.Store(cbc) + + return err +} + +// Encrypt encrypts a single TLS RecordLayer +func (c *TLSPskWithAes128CbcSha256) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) { + cipherSuite, ok := c.cbc.Load().(*ciphersuite.CBC) + if !ok { + return nil, fmt.Errorf("%w, unable to encrypt", errCipherSuiteNotInit) + } + + return cipherSuite.Encrypt(pkt, raw) +} + +// Decrypt decrypts a single TLS RecordLayer +func (c *TLSPskWithAes128CbcSha256) Decrypt(h recordlayer.Header, raw []byte) ([]byte, error) { + cipherSuite, ok := c.cbc.Load().(*ciphersuite.CBC) + if !ok { + return nil, fmt.Errorf("%w, unable to decrypt", errCipherSuiteNotInit) + } + + return cipherSuite.Decrypt(h, raw) +} diff --git a/vendor/github.com/pion/dtls/v2/internal/net/buffer.go b/vendor/github.com/pion/dtls/v2/internal/net/buffer.go new file mode 100644 index 0000000000..9ab290e4cd --- /dev/null +++ b/vendor/github.com/pion/dtls/v2/internal/net/buffer.go @@ -0,0 +1,235 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package net implements DTLS specific networking primitives. +// NOTE: this package is an adaption of pion/transport/packetio that allows for +// storing a remote address alongside each packet in the buffer and implements +// relevant methods of net.PacketConn. If possible, the updates made in this +// repository will be reflected back upstream. If not, it is likely that this +// will be moved to a public package in this repository. +// +// This package was migrated from pion/transport/packetio at +// https://github.com/pion/transport/commit/6890c795c807a617c054149eee40a69d7fdfbfdb +package net + +import ( + "bytes" + "errors" + "io" + "net" + "sync" + "time" + + "github.com/pion/transport/v3/deadline" +) + +// ErrTimeout indicates that deadline was reached before operation could be +// completed. +var ErrTimeout = errors.New("buffer: i/o timeout") + +// AddrPacket is a packet payload and the associated remote address from which +// it was received. +type AddrPacket struct { + addr net.Addr + data bytes.Buffer +} + +// PacketBuffer is a circular buffer for network packets. Each slot in the +// buffer contains the remote address from which the packet was received, as +// well as the packet data. +type PacketBuffer struct { + mutex sync.Mutex + + packets []AddrPacket + write, read int + + // full indicates whether the buffer is full, which is needed to distinguish + // when the write pointer and read pointer are at the same index. + full bool + + notify chan struct{} + closed bool + + readDeadline *deadline.Deadline +} + +// NewPacketBuffer creates a new PacketBuffer. +func NewPacketBuffer() *PacketBuffer { + return &PacketBuffer{ + readDeadline: deadline.New(), + // In the narrow context in which this package is currently used, there + // will always be at least one packet written to the buffer. Therefore, + // we opt to allocate with size of 1 during construction, rather than + // waiting until that first packet is written. + packets: make([]AddrPacket, 1), + full: false, + } +} + +// WriteTo writes a single packet to the buffer. The supplied address will +// remain associated with the packet. 
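+//
+// A minimal usage sketch (illustrative only; the address and payload are
+// hypothetical, not taken from the upstream docs):
+//
+//	buf := NewPacketBuffer()
+//	addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 5684}
+//	_, _ = buf.WriteTo([]byte("hello"), addr)
+//	p := make([]byte, 1500)
+//	n, from, _ := buf.ReadFrom(p) // n == 5, from == addr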
+func (b *PacketBuffer) WriteTo(p []byte, addr net.Addr) (int, error) { + b.mutex.Lock() + + if b.closed { + b.mutex.Unlock() + return 0, io.ErrClosedPipe + } + + var notify chan struct{} + if b.notify != nil { + notify = b.notify + b.notify = nil + } + + // Check to see if we are full. + if b.full { + // If so, grow AddrPacket buffer. + var newSize int + if len(b.packets) < 128 { + // Double the number of packets. + newSize = len(b.packets) * 2 + } else { + // Increase the number of packets by 25%. + newSize = 5 * len(b.packets) / 4 + } + newBuf := make([]AddrPacket, newSize) + var n int + if b.read < b.write { + n = copy(newBuf, b.packets[b.read:b.write]) + } else { + n = copy(newBuf, b.packets[b.read:]) + n += copy(newBuf[n:], b.packets[:b.write]) + } + + b.packets = newBuf + + // Update write pointer to point to new location and mark buffer as not + // full. + b.write = n + b.full = false + } + + // Store the packet at the write pointer. + packet := &b.packets[b.write] + packet.data.Reset() + n, err := packet.data.Write(p) + if err != nil { + b.mutex.Unlock() + return n, err + } + packet.addr = addr + + // Increment write pointer. + b.write++ + + // If the write pointer is equal to the length of the buffer, wrap around. + if len(b.packets) == b.write { + b.write = 0 + } + + // If a write resulted in making write and read pointers equivalent, then we + // are full. + if b.write == b.read { + b.full = true + } + + b.mutex.Unlock() + + if notify != nil { + close(notify) + } + + return n, nil +} + +// ReadFrom reads a single packet from the buffer, or blocks until one is +// available. +func (b *PacketBuffer) ReadFrom(packet []byte) (n int, addr net.Addr, err error) { + select { + case <-b.readDeadline.Done(): + return 0, nil, ErrTimeout + default: + } + + for { + b.mutex.Lock() + + if b.read != b.write || b.full { + ap := b.packets[b.read] + if len(packet) < ap.data.Len() { + b.mutex.Unlock() + return 0, nil, io.ErrShortBuffer + } + + // Copy packet data from buffer. + n, err := ap.data.Read(packet) + if err != nil { + b.mutex.Unlock() + return n, nil, err + } + + // Advance read pointer. + b.read++ + if len(b.packets) == b.read { + b.read = 0 + } + + // If we were full before reading and have successfully read, we are + // no longer full. + if b.full { + b.full = false + } + + b.mutex.Unlock() + + return n, ap.addr, nil + } + + if b.closed { + b.mutex.Unlock() + return 0, nil, io.EOF + } + + if b.notify == nil { + b.notify = make(chan struct{}) + } + notify := b.notify + b.mutex.Unlock() + + select { + case <-b.readDeadline.Done(): + return 0, nil, ErrTimeout + case <-notify: + } + } +} + +// Close closes the buffer, allowing unread packets to be read, but erroring on +// any new writes. +func (b *PacketBuffer) Close() (err error) { + b.mutex.Lock() + + if b.closed { + b.mutex.Unlock() + return nil + } + + notify := b.notify + b.notify = nil + b.closed = true + + b.mutex.Unlock() + + if notify != nil { + close(notify) + } + + return nil +} + +// SetReadDeadline sets the read deadline for the buffer. 
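+// Assuming the semantics of pion/transport's deadline package, a zero value
+// for t clears the deadline, so ReadFrom blocks until a packet arrives or the
+// buffer is closed.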
+func (b *PacketBuffer) SetReadDeadline(t time.Time) error {
+	b.readDeadline.Set(t)
+	return nil
+}
diff --git a/vendor/github.com/pion/dtls/v2/internal/net/udp/packet_conn.go b/vendor/github.com/pion/dtls/v2/internal/net/udp/packet_conn.go
new file mode 100644
index 0000000000..7dafbe23e7
--- /dev/null
+++ b/vendor/github.com/pion/dtls/v2/internal/net/udp/packet_conn.go
@@ -0,0 +1,407 @@
+// SPDX-FileCopyrightText: 2023 The Pion community
+// SPDX-License-Identifier: MIT
+
+// Package udp implements DTLS specific UDP networking primitives.
+// NOTE: this package is an adaptation of pion/transport/udp that allows for
+// routing datagrams based on identifiers other than the remote address. The
+// primary use case for this functionality is routing based on DTLS connection
+// IDs. In order to allow for consumers of this package to treat connections as
+// generic net.PacketConn, routing and identifier establishment is based on
+// custom introspection of datagrams, rather than direct intervention by
+// consumers. If possible, the updates made in this repository will be
+// reflected back upstream. If not, it is likely that this will be moved to a
+// public package in this repository.
+//
+// This package was migrated from pion/transport/udp at
+// https://github.com/pion/transport/commit/6890c795c807a617c054149eee40a69d7fdfbfdb
+package udp
+
+import (
+	"context"
+	"errors"
+	"net"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	idtlsnet "github.com/pion/dtls/v2/internal/net"
+	dtlsnet "github.com/pion/dtls/v2/pkg/net"
+	"github.com/pion/transport/v3/deadline"
+)
+
+const (
+	receiveMTU           = 8192
+	defaultListenBacklog = 128 // same as Linux default
+)
+
+// Typed errors
+var (
+	ErrClosedListener      = errors.New("udp: listener closed")
+	ErrListenQueueExceeded = errors.New("udp: listen queue exceeded")
+)
+
+// listener augments a connection-oriented Listener over a UDP PacketConn
+type listener struct {
+	pConn *net.UDPConn
+
+	accepting      atomic.Value // bool
+	acceptCh       chan *PacketConn
+	doneCh         chan struct{}
+	doneOnce       sync.Once
+	acceptFilter   func([]byte) bool
+	datagramRouter func([]byte) (string, bool)
+	connIdentifier func([]byte) (string, bool)
+
+	connLock sync.Mutex
+	conns    map[string]*PacketConn
+	connWG   sync.WaitGroup
+
+	readWG   sync.WaitGroup
+	errClose atomic.Value // error
+
+	readDoneCh chan struct{}
+	errRead    atomic.Value // error
+}
+
+// Accept waits for and returns the next connection to the listener.
+func (l *listener) Accept() (net.PacketConn, net.Addr, error) {
+	select {
+	case c := <-l.acceptCh:
+		l.connWG.Add(1)
+		return c, c.raddr, nil
+
+	case <-l.readDoneCh:
+		err, _ := l.errRead.Load().(error)
+		return nil, nil, err
+
+	case <-l.doneCh:
+		return nil, nil, ErrClosedListener
+	}
+}
+
+// Close closes the listener.
+// Any blocked Accept operations will be unblocked and return errors.
+func (l *listener) Close() error {
+	var err error
+	l.doneOnce.Do(func() {
+		l.accepting.Store(false)
+		close(l.doneCh)
+
+		l.connLock.Lock()
+		// Close unaccepted connections
+	lclose:
+		for {
+			select {
+			case c := <-l.acceptCh:
+				close(c.doneCh)
+				// If we have an alternate identifier, remove it from the connection
+				// map.
+				if id := c.id.Load(); id != nil {
+					delete(l.conns, id.(string)) //nolint:forcetypeassert
+				}
+				// If we haven't already removed the remote address, remove it
+				// from the connection map.
+				if c.rmraddr.Load() == nil {
+					delete(l.conns, c.raddr.String())
+					c.rmraddr.Store(true)
+				}
+			default:
+				break lclose
+			}
+		}
+		nConns := len(l.conns)
+		l.connLock.Unlock()
+
+		l.connWG.Done()
+
+		if nConns == 0 {
+			// Wait if this is the final connection.
+			l.readWG.Wait()
+			if errClose, ok := l.errClose.Load().(error); ok {
+				err = errClose
+			}
+		} else {
+			err = nil
+		}
+	})
+
+	return err
+}
+
+// Addr returns the listener's network address.
+func (l *listener) Addr() net.Addr {
+	return l.pConn.LocalAddr()
+}
+
+// ListenConfig stores options for listening to an address.
+type ListenConfig struct {
+	// Backlog defines the maximum length of the queue of pending
+	// connections. It is the equivalent of the backlog argument of the
+	// POSIX listen function.
+	// If a connection request arrives when the queue is full,
+	// the request will be silently discarded, unlike TCP.
+	// Set zero to use the default value of 128, which is the same as the
+	// Linux default.
+	Backlog int
+
+	// AcceptFilter determines whether the new conn should be made for
+	// the incoming packet. If not set, any packet creates new conn.
+	AcceptFilter func([]byte) bool
+
+	// DatagramRouter routes an incoming datagram to a connection by extracting
+	// an identifier from its payload
+	DatagramRouter func([]byte) (string, bool)
+
+	// ConnectionIdentifier extracts an identifier from an outgoing packet. If
+	// the identifier is not already associated with the connection, it will be
+	// added.
+	ConnectionIdentifier func([]byte) (string, bool)
+}
+
+// Listen creates a new listener based on the ListenConfig.
+func (lc *ListenConfig) Listen(network string, laddr *net.UDPAddr) (dtlsnet.PacketListener, error) {
+	if lc.Backlog == 0 {
+		lc.Backlog = defaultListenBacklog
+	}
+
+	conn, err := net.ListenUDP(network, laddr)
+	if err != nil {
+		return nil, err
+	}
+
+	l := &listener{
+		pConn:          conn,
+		acceptCh:       make(chan *PacketConn, lc.Backlog),
+		conns:          make(map[string]*PacketConn),
+		doneCh:         make(chan struct{}),
+		acceptFilter:   lc.AcceptFilter,
+		datagramRouter: lc.DatagramRouter,
+		connIdentifier: lc.ConnectionIdentifier,
+		readDoneCh:     make(chan struct{}),
+	}
+
+	l.accepting.Store(true)
+	l.connWG.Add(1)
+	l.readWG.Add(2) // wait readLoop and Close execution routine
+
+	go l.readLoop()
+	go func() {
+		l.connWG.Wait()
+		if err := l.pConn.Close(); err != nil {
+			l.errClose.Store(err)
+		}
+		l.readWG.Done()
+	}()
+
+	return l, nil
+}
+
+// Listen creates a new listener using default ListenConfig.
+func Listen(network string, laddr *net.UDPAddr) (dtlsnet.PacketListener, error) {
+	return (&ListenConfig{}).Listen(network, laddr)
+}
+
+// readLoop dispatches packets to the proper connection, creating a new one if
+// necessary, until all connections are closed.
+func (l *listener) readLoop() {
+	defer l.readWG.Done()
+	defer close(l.readDoneCh)
+
+	buf := make([]byte, receiveMTU)
+
+	for {
+		n, raddr, err := l.pConn.ReadFrom(buf)
+		if err != nil {
+			l.errRead.Store(err)
+			return
+		}
+		conn, ok, err := l.getConn(raddr, buf[:n])
+		if err != nil {
+			continue
+		}
+		if ok {
+			_, _ = conn.buffer.WriteTo(buf[:n], raddr)
+		}
+	}
+}
+
+// getConn gets an existing connection or creates a new one.
+func (l *listener) getConn(raddr net.Addr, buf []byte) (*PacketConn, bool, error) {
+	l.connLock.Lock()
+	defer l.connLock.Unlock()
+	// If we have a custom resolver, use it.
+ if l.datagramRouter != nil { + if id, ok := l.datagramRouter(buf); ok { + if conn, ok := l.conns[id]; ok { + return conn, true, nil + } + } + } + + // If we don't have a custom resolver, or we were unable to find an + // associated connection, fall back to remote address. + conn, ok := l.conns[raddr.String()] + if !ok { + if isAccepting, ok := l.accepting.Load().(bool); !isAccepting || !ok { + return nil, false, ErrClosedListener + } + if l.acceptFilter != nil { + if !l.acceptFilter(buf) { + return nil, false, nil + } + } + conn = l.newPacketConn(raddr) + select { + case l.acceptCh <- conn: + l.conns[raddr.String()] = conn + default: + return nil, false, ErrListenQueueExceeded + } + } + return conn, true, nil +} + +// PacketConn is a net.PacketConn implementation that is able to dictate its +// routing ID via an alternate identifier from its remote address. Internal +// buffering is performed for reads, and writes are passed through to the +// underlying net.PacketConn. +type PacketConn struct { + listener *listener + + raddr net.Addr + rmraddr atomic.Value // bool + id atomic.Value // string + + buffer *idtlsnet.PacketBuffer + + doneCh chan struct{} + doneOnce sync.Once + + writeDeadline *deadline.Deadline +} + +// newPacketConn constructs a new PacketConn. +func (l *listener) newPacketConn(raddr net.Addr) *PacketConn { + return &PacketConn{ + listener: l, + raddr: raddr, + buffer: idtlsnet.NewPacketBuffer(), + doneCh: make(chan struct{}), + writeDeadline: deadline.New(), + } +} + +// ReadFrom reads a single packet payload and its associated remote address from +// the underlying buffer. +func (c *PacketConn) ReadFrom(p []byte) (int, net.Addr, error) { + return c.buffer.ReadFrom(p) +} + +// WriteTo writes len(p) bytes from p to the specified address. +func (c *PacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + // If we have a connection identifier, check to see if the outgoing packet + // sets it. + if c.listener.connIdentifier != nil { + id := c.id.Load() + // Only update establish identifier if we haven't already done so. + if id == nil { + candidate, ok := c.listener.connIdentifier(p) + // If we have an identifier, add entry to connection map. + if ok { + c.listener.connLock.Lock() + c.listener.conns[candidate] = c + c.listener.connLock.Unlock() + c.id.Store(candidate) + } + } + // If we are writing to a remote address that differs from the initial, + // we have an alternate identifier established, and we haven't already + // freed the remote address, free the remote address to be used by + // another connection. + // Note: this strategy results in holding onto a remote address after it + // is potentially no longer in use by the client. However, releasing + // earlier means that we could miss some packets that should have been + // routed to this connection. Ideally, we would drop the connection + // entry for the remote address as soon as the client starts sending + // using an alternate identifier, but in practice this proves + // challenging because any client could spoof a connection identifier, + // resulting in the remote address entry being dropped prior to the + // "real" client transitioning to sending using the alternate + // identifier. 
+ if id != nil && c.rmraddr.Load() == nil && addr.String() != c.raddr.String() { + c.listener.connLock.Lock() + delete(c.listener.conns, c.raddr.String()) + c.rmraddr.Store(true) + c.listener.connLock.Unlock() + } + } + + select { + case <-c.writeDeadline.Done(): + return 0, context.DeadlineExceeded + default: + } + return c.listener.pConn.WriteTo(p, addr) +} + +// Close closes the conn and releases any Read calls +func (c *PacketConn) Close() error { + var err error + c.doneOnce.Do(func() { + c.listener.connWG.Done() + close(c.doneCh) + c.listener.connLock.Lock() + // If we have an alternate identifier, remove it from the connection + // map. + if id := c.id.Load(); id != nil { + delete(c.listener.conns, id.(string)) //nolint:forcetypeassert + } + // If we haven't already removed the remote address, remove it from the + // connection map. + if c.rmraddr.Load() == nil { + delete(c.listener.conns, c.raddr.String()) + c.rmraddr.Store(true) + } + nConns := len(c.listener.conns) + c.listener.connLock.Unlock() + + if isAccepting, ok := c.listener.accepting.Load().(bool); nConns == 0 && !isAccepting && ok { + // Wait if this is the final connection + c.listener.readWG.Wait() + if errClose, ok := c.listener.errClose.Load().(error); ok { + err = errClose + } + } else { + err = nil + } + + if errBuf := c.buffer.Close(); errBuf != nil && err == nil { + err = errBuf + } + }) + + return err +} + +// LocalAddr implements net.PacketConn.LocalAddr. +func (c *PacketConn) LocalAddr() net.Addr { + return c.listener.pConn.LocalAddr() +} + +// SetDeadline implements net.PacketConn.SetDeadline. +func (c *PacketConn) SetDeadline(t time.Time) error { + c.writeDeadline.Set(t) + return c.SetReadDeadline(t) +} + +// SetReadDeadline implements net.PacketConn.SetReadDeadline. +func (c *PacketConn) SetReadDeadline(t time.Time) error { + return c.buffer.SetReadDeadline(t) +} + +// SetWriteDeadline implements net.PacketConn.SetWriteDeadline. +func (c *PacketConn) SetWriteDeadline(t time.Time) error { + c.writeDeadline.Set(t) + // Write deadline of underlying connection should not be changed + // since the connection can be shared. + return nil +} diff --git a/vendor/github.com/pion/dtls/v2/internal/util/util.go b/vendor/github.com/pion/dtls/v2/internal/util/util.go new file mode 100644 index 0000000000..382a0e1cdd --- /dev/null +++ b/vendor/github.com/pion/dtls/v2/internal/util/util.go @@ -0,0 +1,51 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package util contains small helpers used across the repo +package util + +import ( + "encoding/binary" + + "golang.org/x/crypto/cryptobyte" +) + +// BigEndianUint24 returns the value of a big endian uint24 +func BigEndianUint24(raw []byte) uint32 { + if len(raw) < 3 { + return 0 + } + + rawCopy := make([]byte, 4) + copy(rawCopy[1:], raw) + return binary.BigEndian.Uint32(rawCopy) +} + +// PutBigEndianUint24 encodes a uint24 and places into out +func PutBigEndianUint24(out []byte, in uint32) { + tmp := make([]byte, 4) + binary.BigEndian.PutUint32(tmp, in) + copy(out, tmp[1:]) +} + +// PutBigEndianUint48 encodes a uint64 and places into out +func PutBigEndianUint48(out []byte, in uint64) { + tmp := make([]byte, 8) + binary.BigEndian.PutUint64(tmp, in) + copy(out, tmp[2:]) +} + +// Max returns the larger value +func Max(a, b int) int { + if a > b { + return a + } + return b +} + +// AddUint48 appends a big-endian, 48-bit value to the byte string. +// Remove if / when https://github.com/golang/crypto/pull/265 is merged +// upstream. 
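+// For example, AddUint48(b, 0x0102030405) appends the six bytes
+// 00 01 02 03 04 05 to the builder.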
+func AddUint48(b *cryptobyte.Builder, v uint64) {
+	b.AddBytes([]byte{byte(v >> 40), byte(v >> 32), byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)})
+}
diff --git a/vendor/github.com/pion/dtls/v2/listener.go b/vendor/github.com/pion/dtls/v2/listener.go
new file mode 100644
index 0000000000..90dbbb427c
--- /dev/null
+++ b/vendor/github.com/pion/dtls/v2/listener.go
@@ -0,0 +1,90 @@
+// SPDX-FileCopyrightText: 2023 The Pion community
+// SPDX-License-Identifier: MIT
+
+package dtls
+
+import (
+	"net"
+
+	"github.com/pion/dtls/v2/internal/net/udp"
+	dtlsnet "github.com/pion/dtls/v2/pkg/net"
+	"github.com/pion/dtls/v2/pkg/protocol"
+	"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
+)
+
+// Listen creates a DTLS listener
+func Listen(network string, laddr *net.UDPAddr, config *Config) (net.Listener, error) {
+	if err := validateConfig(config); err != nil {
+		return nil, err
+	}
+
+	lc := udp.ListenConfig{
+		AcceptFilter: func(packet []byte) bool {
+			pkts, err := recordlayer.UnpackDatagram(packet)
+			if err != nil || len(pkts) < 1 {
+				return false
+			}
+			h := &recordlayer.Header{}
+			if err := h.Unmarshal(pkts[0]); err != nil {
+				return false
+			}
+			return h.ContentType == protocol.ContentTypeHandshake
+		},
+	}
+	// If connection ID support is enabled, then they must be supported in
+	// routing.
+	if config.ConnectionIDGenerator != nil {
+		lc.DatagramRouter = cidDatagramRouter(len(config.ConnectionIDGenerator()))
+		lc.ConnectionIdentifier = cidConnIdentifier()
+	}
+	parent, err := lc.Listen(network, laddr)
+	if err != nil {
+		return nil, err
+	}
+	return &listener{
+		config: config,
+		parent: parent,
+	}, nil
+}
+
+// NewListener creates a DTLS listener which accepts connections from an inner Listener.
+func NewListener(inner dtlsnet.PacketListener, config *Config) (net.Listener, error) {
+	if err := validateConfig(config); err != nil {
+		return nil, err
+	}
+
+	return &listener{
+		config: config,
+		parent: inner,
+	}, nil
+}
+
+// listener represents a DTLS listener
+type listener struct {
+	config *Config
+	parent dtlsnet.PacketListener
+}
+
+// Accept waits for and returns the next connection to the listener.
+// You have to either close or read on all connections that are created.
+// Connection handshake will time out using ConnectContextMaker in the Config.
+// If you want to specify the timeout duration, set ConnectContextMaker.
+func (l *listener) Accept() (net.Conn, error) {
+	c, raddr, err := l.parent.Accept()
+	if err != nil {
+		return nil, err
+	}
+	return Server(c, raddr, l.config)
+}
+
+// Close closes the listener.
+// Any blocked Accept operations will be unblocked and return errors.
+// Already Accepted connections are not closed.
+func (l *listener) Close() error {
+	return l.parent.Close()
+}
+
+// Addr returns the listener's network address.
+func (l *listener) Addr() net.Addr {
+	return l.parent.Addr()
+}
diff --git a/vendor/github.com/pion/dtls/v2/packet.go b/vendor/github.com/pion/dtls/v2/packet.go
new file mode 100644
index 0000000000..052c33a19f
--- /dev/null
+++ b/vendor/github.com/pion/dtls/v2/packet.go
@@ -0,0 +1,15 @@
+// SPDX-FileCopyrightText: 2023 The Pion community
+// SPDX-License-Identifier: MIT
+
+package dtls
+
+import (
+	"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
+)
+
+type packet struct {
+	record                   *recordlayer.RecordLayer
+	shouldEncrypt            bool
+	shouldWrapCID            bool
+	resetLocalSequenceNumber bool
+}
diff --git a/vendor/github.com/pion/dtls/v2/pkg/crypto/ciphersuite/cbc.go b/vendor/github.com/pion/dtls/v2/pkg/crypto/ciphersuite/cbc.go
new file mode 100644
index 0000000000..008a8365b8
--- /dev/null
+++ b/vendor/github.com/pion/dtls/v2/pkg/crypto/ciphersuite/cbc.go
@@ -0,0 +1,227 @@
+// SPDX-FileCopyrightText: 2023 The Pion community
+// SPDX-License-Identifier: MIT
+
+package ciphersuite
+
+import ( //nolint:gci
+	"crypto/aes"
+	"crypto/cipher"
+	"crypto/hmac"
+	"crypto/rand"
+	"encoding/binary"
+	"hash"
+
+	"github.com/pion/dtls/v2/internal/util"
+	"github.com/pion/dtls/v2/pkg/crypto/prf"
+	"github.com/pion/dtls/v2/pkg/protocol"
+	"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
+	"golang.org/x/crypto/cryptobyte"
+)
+
+// block ciphers using cipher block chaining.
+type cbcMode interface {
+	cipher.BlockMode
+	SetIV([]byte)
+}
+
+// CBC Provides an API to Encrypt/Decrypt DTLS 1.2 Packets
+type CBC struct {
+	writeCBC, readCBC cbcMode
+	writeMac, readMac []byte
+	h                 prf.HashFunc
+}
+
+// NewCBC creates a DTLS CBC Cipher
+func NewCBC(localKey, localWriteIV, localMac, remoteKey, remoteWriteIV, remoteMac []byte, h prf.HashFunc) (*CBC, error) {
+	writeBlock, err := aes.NewCipher(localKey)
+	if err != nil {
+		return nil, err
+	}
+
+	readBlock, err := aes.NewCipher(remoteKey)
+	if err != nil {
+		return nil, err
+	}
+
+	writeCBC, ok := cipher.NewCBCEncrypter(writeBlock, localWriteIV).(cbcMode)
+	if !ok {
+		return nil, errFailedToCast
+	}
+
+	readCBC, ok := cipher.NewCBCDecrypter(readBlock, remoteWriteIV).(cbcMode)
+	if !ok {
+		return nil, errFailedToCast
+	}
+
+	return &CBC{
+		writeCBC: writeCBC,
+		writeMac: localMac,
+
+		readCBC: readCBC,
+		readMac: remoteMac,
+		h:       h,
+	}, nil
+}
+
+// Encrypt encrypts a DTLS RecordLayer message
+func (c *CBC) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) {
+	payload := raw[pkt.Header.Size():]
+	raw = raw[:pkt.Header.Size()]
+	blockSize := c.writeCBC.BlockSize()
+
+	// Generate + Append MAC
+	h := pkt.Header
+
+	var err error
+	var mac []byte
+	if h.ContentType == protocol.ContentTypeConnectionID {
+		mac, err = c.hmacCID(h.Epoch, h.SequenceNumber, h.Version, payload, c.writeMac, c.h, h.ConnectionID)
+	} else {
+		mac, err = c.hmac(h.Epoch, h.SequenceNumber, h.ContentType, h.Version, payload, c.writeMac, c.h)
+	}
+	if err != nil {
+		return nil, err
+	}
+	payload = append(payload, mac...)
+
+	// Generate + Append padding
+	padding := make([]byte, blockSize-len(payload)%blockSize)
+	paddingLen := len(padding)
+	for i := 0; i < paddingLen; i++ {
+		padding[i] = byte(paddingLen - 1)
+	}
+	payload = append(payload, padding...)
+
+	// Generate IV
+	iv := make([]byte, blockSize)
+	if _, err := rand.Read(iv); err != nil {
+		return nil, err
+	}
+
+	// Set IV + Encrypt + Prepend IV
+	c.writeCBC.SetIV(iv)
+	c.writeCBC.CryptBlocks(payload, payload)
+	payload = append(iv, payload...)
+
+	// Prepend unencrypted header with encrypted payload
+	raw = append(raw, payload...)
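+	// At this point raw is laid out as: header | IV | ciphertext(payload + MAC + padding).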
+
+	// Update recordLayer size to include IV+MAC+Padding
+	binary.BigEndian.PutUint16(raw[pkt.Header.Size()-2:], uint16(len(raw)-pkt.Header.Size()))
+
+	return raw, nil
+}
+
+// Decrypt decrypts a DTLS RecordLayer message
+func (c *CBC) Decrypt(h recordlayer.Header, in []byte) ([]byte, error) {
+	blockSize := c.readCBC.BlockSize()
+	mac := c.h()
+
+	if err := h.Unmarshal(in); err != nil {
+		return nil, err
+	}
+	body := in[h.Size():]
+
+	switch {
+	case h.ContentType == protocol.ContentTypeChangeCipherSpec:
+		// Nothing to encrypt with ChangeCipherSpec
+		return in, nil
+	case len(body)%blockSize != 0 || len(body) < blockSize+util.Max(mac.Size()+1, blockSize):
+		return nil, errNotEnoughRoomForNonce
+	}
+
+	// Set + remove per record IV
+	c.readCBC.SetIV(body[:blockSize])
+	body = body[blockSize:]
+
+	// Decrypt
+	c.readCBC.CryptBlocks(body, body)
+
+	// Padding+MAC needs to be checked in constant time
+	// Otherwise we reveal information about the level of correctness
+	paddingLen, paddingGood := examinePadding(body)
+	if paddingGood != 255 {
+		return nil, errInvalidMAC
+	}
+
+	macSize := mac.Size()
+	if len(body) < macSize {
+		return nil, errInvalidMAC
+	}
+
+	dataEnd := len(body) - macSize - paddingLen
+
+	expectedMAC := body[dataEnd : dataEnd+macSize]
+	var err error
+	var actualMAC []byte
+	if h.ContentType == protocol.ContentTypeConnectionID {
+		actualMAC, err = c.hmacCID(h.Epoch, h.SequenceNumber, h.Version, body[:dataEnd], c.readMac, c.h, h.ConnectionID)
+	} else {
+		actualMAC, err = c.hmac(h.Epoch, h.SequenceNumber, h.ContentType, h.Version, body[:dataEnd], c.readMac, c.h)
+	}
+	// Compute Local MAC and compare
+	if err != nil || !hmac.Equal(actualMAC, expectedMAC) {
+		return nil, errInvalidMAC
+	}
+
+	return append(in[:h.Size()], body[:dataEnd]...), nil
+}
+
+func (c *CBC) hmac(epoch uint16, sequenceNumber uint64, contentType protocol.ContentType, protocolVersion protocol.Version, payload []byte, key []byte, hf func() hash.Hash) ([]byte, error) {
+	h := hmac.New(hf, key)
+
+	msg := make([]byte, 13)
+
+	binary.BigEndian.PutUint16(msg, epoch)
+	util.PutBigEndianUint48(msg[2:], sequenceNumber)
+	msg[8] = byte(contentType)
+	msg[9] = protocolVersion.Major
+	msg[10] = protocolVersion.Minor
+	binary.BigEndian.PutUint16(msg[11:], uint16(len(payload)))
+
+	if _, err := h.Write(msg); err != nil {
+		return nil, err
+	}
+	if _, err := h.Write(payload); err != nil {
+		return nil, err
+	}
+
+	return h.Sum(nil), nil
+}
+
+// hmacCID calculates a MAC according to
+// https://datatracker.ietf.org/doc/html/rfc9146#section-5.1
+func (c *CBC) hmacCID(epoch uint16, sequenceNumber uint64, protocolVersion protocol.Version, payload []byte, key []byte, hf func() hash.Hash, cid []byte) ([]byte, error) {
+	// Must unmarshal inner plaintext in order to perform MAC.
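+	// The MAC input below mirrors the AAD layout of RFC 9146: the sequence
+	// number placeholder, the tls12_cid content type, the CID length and the
+	// CID itself, followed by the inner plaintext fields.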
+	ip := &recordlayer.InnerPlaintext{}
+	if err := ip.Unmarshal(payload); err != nil {
+		return nil, err
+	}
+
+	h := hmac.New(hf, key)
+
+	var msg cryptobyte.Builder
+
+	msg.AddUint64(seqNumPlaceholder)
+	msg.AddUint8(uint8(protocol.ContentTypeConnectionID))
+	msg.AddUint8(uint8(len(cid)))
+	msg.AddUint8(uint8(protocol.ContentTypeConnectionID))
+	msg.AddUint8(protocolVersion.Major)
+	msg.AddUint8(protocolVersion.Minor)
+	msg.AddUint16(epoch)
+	util.AddUint48(&msg, sequenceNumber)
+	msg.AddBytes(cid)
+	msg.AddUint16(uint16(len(payload)))
+	msg.AddBytes(ip.Content)
+	msg.AddUint8(uint8(ip.RealType))
+	msg.AddBytes(make([]byte, ip.Zeros))
+
+	if _, err := h.Write(msg.BytesOrPanic()); err != nil {
+		return nil, err
+	}
+	if _, err := h.Write(payload); err != nil {
+		return nil, err
+	}
+
+	return h.Sum(nil), nil
+}
diff --git a/vendor/github.com/pion/dtls/v2/pkg/crypto/ciphersuite/ccm.go b/vendor/github.com/pion/dtls/v2/pkg/crypto/ciphersuite/ccm.go
new file mode 100644
index 0000000000..6fb185d85f
--- /dev/null
+++ b/vendor/github.com/pion/dtls/v2/pkg/crypto/ciphersuite/ccm.go
@@ -0,0 +1,117 @@
+// SPDX-FileCopyrightText: 2023 The Pion community
+// SPDX-License-Identifier: MIT
+
+package ciphersuite
+
+import (
+	"crypto/aes"
+	"crypto/rand"
+	"encoding/binary"
+	"fmt"
+
+	"github.com/pion/dtls/v2/pkg/crypto/ccm"
+	"github.com/pion/dtls/v2/pkg/protocol"
+	"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
+)
+
+// CCMTagLen is the length of Authentication Tag
+type CCMTagLen int
+
+// CCM Enums
+const (
+	CCMTagLength8  CCMTagLen = 8
+	CCMTagLength   CCMTagLen = 16
+	ccmNonceLength           = 12
+)
+
+// CCM Provides an API to Encrypt/Decrypt DTLS 1.2 Packets
+type CCM struct {
+	localCCM, remoteCCM         ccm.CCM
+	localWriteIV, remoteWriteIV []byte
+	tagLen                      CCMTagLen
+}
+
+// NewCCM creates a DTLS CCM Cipher
+func NewCCM(tagLen CCMTagLen, localKey, localWriteIV, remoteKey, remoteWriteIV []byte) (*CCM, error) {
+	localBlock, err := aes.NewCipher(localKey)
+	if err != nil {
+		return nil, err
+	}
+	localCCM, err := ccm.NewCCM(localBlock, int(tagLen), ccmNonceLength)
+	if err != nil {
+		return nil, err
+	}
+
+	remoteBlock, err := aes.NewCipher(remoteKey)
+	if err != nil {
+		return nil, err
+	}
+	remoteCCM, err := ccm.NewCCM(remoteBlock, int(tagLen), ccmNonceLength)
+	if err != nil {
+		return nil, err
+	}
+
+	return &CCM{
+		localCCM:      localCCM,
+		localWriteIV:  localWriteIV,
+		remoteCCM:     remoteCCM,
+		remoteWriteIV: remoteWriteIV,
+		tagLen:        tagLen,
+	}, nil
+}
+
+// Encrypt encrypts a DTLS RecordLayer message
+func (c *CCM) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) {
+	payload := raw[pkt.Header.Size():]
+	raw = raw[:pkt.Header.Size()]
+
+	nonce := append(append([]byte{}, c.localWriteIV[:4]...), make([]byte, 8)...)
+	if _, err := rand.Read(nonce[4:]); err != nil {
+		return nil, err
+	}
+
+	var additionalData []byte
+	if pkt.Header.ContentType == protocol.ContentTypeConnectionID {
+		additionalData = generateAEADAdditionalDataCID(&pkt.Header, len(payload))
+	} else {
+		additionalData = generateAEADAdditionalData(&pkt.Header, len(payload))
+	}
+	encryptedPayload := c.localCCM.Seal(nil, nonce, payload, additionalData)
+
+	encryptedPayload = append(nonce[4:], encryptedPayload...)
+	raw = append(raw, encryptedPayload...)
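+	// raw is now laid out as: header | 8-byte explicit nonce | ciphertext + tag.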
+ + // Update recordLayer size to include explicit nonce + binary.BigEndian.PutUint16(raw[pkt.Header.Size()-2:], uint16(len(raw)-pkt.Header.Size())) + return raw, nil +} + +// Decrypt decrypts a DTLS RecordLayer message +func (c *CCM) Decrypt(h recordlayer.Header, in []byte) ([]byte, error) { + if err := h.Unmarshal(in); err != nil { + return nil, err + } + switch { + case h.ContentType == protocol.ContentTypeChangeCipherSpec: + // Nothing to encrypt with ChangeCipherSpec + return in, nil + case len(in) <= (8 + h.Size()): + return nil, errNotEnoughRoomForNonce + } + + nonce := append(append([]byte{}, c.remoteWriteIV[:4]...), in[h.Size():h.Size()+8]...) + out := in[h.Size()+8:] + + var additionalData []byte + if h.ContentType == protocol.ContentTypeConnectionID { + additionalData = generateAEADAdditionalDataCID(&h, len(out)-int(c.tagLen)) + } else { + additionalData = generateAEADAdditionalData(&h, len(out)-int(c.tagLen)) + } + var err error + out, err = c.remoteCCM.Open(out[:0], nonce, out, additionalData) + if err != nil { + return nil, fmt.Errorf("%w: %v", errDecryptPacket, err) //nolint:errorlint + } + return append(in[:h.Size()], out...), nil +} diff --git a/vendor/github.com/pion/dtls/v2/pkg/crypto/ciphersuite/ciphersuite.go b/vendor/github.com/pion/dtls/v2/pkg/crypto/ciphersuite/ciphersuite.go new file mode 100644 index 0000000000..a3130be123 --- /dev/null +++ b/vendor/github.com/pion/dtls/v2/pkg/crypto/ciphersuite/ciphersuite.go @@ -0,0 +1,104 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package ciphersuite provides the crypto operations needed for a DTLS CipherSuite +package ciphersuite + +import ( + "encoding/binary" + "errors" + + "github.com/pion/dtls/v2/internal/util" + "github.com/pion/dtls/v2/pkg/protocol" + "github.com/pion/dtls/v2/pkg/protocol/recordlayer" + "golang.org/x/crypto/cryptobyte" +) + +const ( + // 8 bytes of 0xff. 
+ // https://datatracker.ietf.org/doc/html/rfc9146#name-record-payload-protection + seqNumPlaceholder = 0xffffffffffffffff +) + +var ( + errNotEnoughRoomForNonce = &protocol.InternalError{Err: errors.New("buffer not long enough to contain nonce")} //nolint:goerr113 + errDecryptPacket = &protocol.TemporaryError{Err: errors.New("failed to decrypt packet")} //nolint:goerr113 + errInvalidMAC = &protocol.TemporaryError{Err: errors.New("invalid mac")} //nolint:goerr113 + errFailedToCast = &protocol.FatalError{Err: errors.New("failed to cast")} //nolint:goerr113 +) + +func generateAEADAdditionalData(h *recordlayer.Header, payloadLen int) []byte { + var additionalData [13]byte + + // SequenceNumber MUST be set first + // we only want uint48, clobbering an extra 2 (using uint64, Golang doesn't have uint48) + binary.BigEndian.PutUint64(additionalData[:], h.SequenceNumber) + binary.BigEndian.PutUint16(additionalData[:], h.Epoch) + additionalData[8] = byte(h.ContentType) + additionalData[9] = h.Version.Major + additionalData[10] = h.Version.Minor + binary.BigEndian.PutUint16(additionalData[len(additionalData)-2:], uint16(payloadLen)) + + return additionalData[:] +} + +// generateAEADAdditionalDataCID generates additional data for AEAD ciphers +// according to https://datatracker.ietf.org/doc/html/rfc9146#name-aead-ciphers +func generateAEADAdditionalDataCID(h *recordlayer.Header, payloadLen int) []byte { + var b cryptobyte.Builder + + b.AddUint64(seqNumPlaceholder) + b.AddUint8(uint8(protocol.ContentTypeConnectionID)) + b.AddUint8(uint8(len(h.ConnectionID))) + b.AddUint8(uint8(protocol.ContentTypeConnectionID)) + b.AddUint8(h.Version.Major) + b.AddUint8(h.Version.Minor) + b.AddUint16(h.Epoch) + util.AddUint48(&b, h.SequenceNumber) + b.AddBytes(h.ConnectionID) + b.AddUint16(uint16(payloadLen)) + + return b.BytesOrPanic() +} + +// examinePadding returns, in constant time, the length of the padding to remove +// from the end of payload. It also returns a byte which is equal to 255 if the +// padding was valid and 0 otherwise. See RFC 2246, Section 6.2.3.2. +// +// https://github.com/golang/go/blob/039c2081d1178f90a8fa2f4e6958693129f8de33/src/crypto/tls/conn.go#L245 +func examinePadding(payload []byte) (toRemove int, good byte) { + if len(payload) < 1 { + return 0, 0 + } + + paddingLen := payload[len(payload)-1] + t := uint(len(payload)-1) - uint(paddingLen) + // if len(payload) >= (paddingLen - 1) then the MSB of t is zero + good = byte(int32(^t) >> 31) + + // The maximum possible padding length plus the actual length field + toCheck := 256 + // The length of the padded data is public, so we can use an if here + if toCheck > len(payload) { + toCheck = len(payload) + } + + for i := 0; i < toCheck; i++ { + t := uint(paddingLen) - uint(i) + // if i <= paddingLen then the MSB of t is zero + mask := byte(int32(^t) >> 31) + b := payload[len(payload)-1-i] + good &^= mask&paddingLen ^ mask&b + } + + // We AND together the bits of good and replicate the result across + // all the bits. 
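+	// For example, good == 0xff stays 0xff through the shifts below, while any
+	// cleared bit (e.g. good == 0x7f) collapses the final value to 0x00.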
+	good &= good << 4
+	good &= good << 2
+	good &= good << 1
+	good = uint8(int8(good) >> 7)
+
+	toRemove = int(paddingLen) + 1
+
+	return toRemove, good
+}
diff --git a/vendor/github.com/pion/dtls/v2/pkg/crypto/ciphersuite/gcm.go b/vendor/github.com/pion/dtls/v2/pkg/crypto/ciphersuite/gcm.go
new file mode 100644
index 0000000000..1d09c8eb95
--- /dev/null
+++ b/vendor/github.com/pion/dtls/v2/pkg/crypto/ciphersuite/gcm.go
@@ -0,0 +1,112 @@
+// SPDX-FileCopyrightText: 2023 The Pion community
+// SPDX-License-Identifier: MIT
+
+package ciphersuite
+
+import (
+	"crypto/aes"
+	"crypto/cipher"
+	"crypto/rand"
+	"encoding/binary"
+	"fmt"
+
+	"github.com/pion/dtls/v2/pkg/protocol"
+	"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
+)
+
+const (
+	gcmTagLength   = 16
+	gcmNonceLength = 12
+)
+
+// GCM Provides an API to Encrypt/Decrypt DTLS 1.2 Packets
+type GCM struct {
+	localGCM, remoteGCM         cipher.AEAD
+	localWriteIV, remoteWriteIV []byte
+}
+
+// NewGCM creates a DTLS GCM Cipher
+func NewGCM(localKey, localWriteIV, remoteKey, remoteWriteIV []byte) (*GCM, error) {
+	localBlock, err := aes.NewCipher(localKey)
+	if err != nil {
+		return nil, err
+	}
+	localGCM, err := cipher.NewGCM(localBlock)
+	if err != nil {
+		return nil, err
+	}
+
+	remoteBlock, err := aes.NewCipher(remoteKey)
+	if err != nil {
+		return nil, err
+	}
+	remoteGCM, err := cipher.NewGCM(remoteBlock)
+	if err != nil {
+		return nil, err
+	}
+
+	return &GCM{
+		localGCM:      localGCM,
+		localWriteIV:  localWriteIV,
+		remoteGCM:     remoteGCM,
+		remoteWriteIV: remoteWriteIV,
+	}, nil
+}
+
+// Encrypt encrypts a DTLS RecordLayer message
+func (g *GCM) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) {
+	payload := raw[pkt.Header.Size():]
+	raw = raw[:pkt.Header.Size()]
+
+	nonce := make([]byte, gcmNonceLength)
+	copy(nonce, g.localWriteIV[:4])
+	if _, err := rand.Read(nonce[4:]); err != nil {
+		return nil, err
+	}
+
+	var additionalData []byte
+	if pkt.Header.ContentType == protocol.ContentTypeConnectionID {
+		additionalData = generateAEADAdditionalDataCID(&pkt.Header, len(payload))
+	} else {
+		additionalData = generateAEADAdditionalData(&pkt.Header, len(payload))
+	}
+	encryptedPayload := g.localGCM.Seal(nil, nonce, payload, additionalData)
+	r := make([]byte, len(raw)+len(nonce[4:])+len(encryptedPayload))
+	copy(r, raw)
+	copy(r[len(raw):], nonce[4:])
+	copy(r[len(raw)+len(nonce[4:]):], encryptedPayload)
+
+	// Update recordLayer size to include explicit nonce
+	binary.BigEndian.PutUint16(r[pkt.Header.Size()-2:], uint16(len(r)-pkt.Header.Size()))
+	return r, nil
+}
+
+// Decrypt decrypts a DTLS RecordLayer message
+func (g *GCM) Decrypt(h recordlayer.Header, in []byte) ([]byte, error) {
+	err := h.Unmarshal(in)
+	switch {
+	case err != nil:
+		return nil, err
+	case h.ContentType == protocol.ContentTypeChangeCipherSpec:
+		// Nothing to encrypt with ChangeCipherSpec
+		return in, nil
+	case len(in) <= (8 + h.Size()):
+		return nil, errNotEnoughRoomForNonce
+	}
+
+	nonce := make([]byte, 0, gcmNonceLength)
+	nonce = append(append(nonce, g.remoteWriteIV[:4]...), in[h.Size():h.Size()+8]...)
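+	// The 12-byte AEAD nonce is the 4-byte implicit write IV followed by the
+	// 8-byte explicit nonce carried at the front of the record body.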
+ out := in[h.Size()+8:] + + var additionalData []byte + if h.ContentType == protocol.ContentTypeConnectionID { + additionalData = generateAEADAdditionalDataCID(&h, len(out)-gcmTagLength) + } else { + additionalData = generateAEADAdditionalData(&h, len(out)-gcmTagLength) + } + out, err = g.remoteGCM.Open(out[:0], nonce, out, additionalData) + if err != nil { + return nil, fmt.Errorf("%w: %v", errDecryptPacket, err) //nolint:errorlint + } + return append(in[:h.Size()], out...), nil +} diff --git a/vendor/github.com/pion/dtls/v2/pkg/net/net.go b/vendor/github.com/pion/dtls/v2/pkg/net/net.go new file mode 100644 index 0000000000..e76daf56aa --- /dev/null +++ b/vendor/github.com/pion/dtls/v2/pkg/net/net.go @@ -0,0 +1,108 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package net defines packet-oriented primitives that are compatible with net +// in the standard library. +package net + +import ( + "net" + "time" +) + +// A PacketListener is the same as net.Listener but returns a net.PacketConn on +// Accept() rather than a net.Conn. +// +// Multiple goroutines may invoke methods on a PacketListener simultaneously. +type PacketListener interface { + // Accept waits for and returns the next connection to the listener. + Accept() (net.PacketConn, net.Addr, error) + + // Close closes the listener. + // Any blocked Accept operations will be unblocked and return errors. + Close() error + + // Addr returns the listener's network address. + Addr() net.Addr +} + +// PacketListenerFromListener converts a net.Listener into a +// dtlsnet.PacketListener. +func PacketListenerFromListener(l net.Listener) PacketListener { + return &packetListenerWrapper{ + l: l, + } +} + +// packetListenerWrapper wraps a net.Listener and implements +// dtlsnet.PacketListener. +type packetListenerWrapper struct { + l net.Listener +} + +// Accept calls Accept on the underlying net.Listener and converts the returned +// net.Conn into a net.PacketConn. +func (p *packetListenerWrapper) Accept() (net.PacketConn, net.Addr, error) { + c, err := p.l.Accept() + if err != nil { + return PacketConnFromConn(c), nil, err + } + return PacketConnFromConn(c), c.RemoteAddr(), nil +} + +// Close closes the underlying net.Listener. +func (p *packetListenerWrapper) Close() error { + return p.l.Close() +} + +// Addr returns the address of the underlying net.Listener. +func (p *packetListenerWrapper) Addr() net.Addr { + return p.l.Addr() +} + +// PacketConnFromConn converts a net.Conn into a net.PacketConn. +func PacketConnFromConn(conn net.Conn) net.PacketConn { + return &packetConnWrapper{conn} +} + +// packetConnWrapper wraps a net.Conn and implements net.PacketConn. +type packetConnWrapper struct { + conn net.Conn +} + +// ReadFrom reads from the underlying net.Conn and returns its remote address. +func (p *packetConnWrapper) ReadFrom(b []byte) (int, net.Addr, error) { + n, err := p.conn.Read(b) + return n, p.conn.RemoteAddr(), err +} + +// WriteTo writes to the underlying net.Conn. +func (p *packetConnWrapper) WriteTo(b []byte, _ net.Addr) (int, error) { + n, err := p.conn.Write(b) + return n, err +} + +// Close closes the underlying net.Conn. +func (p *packetConnWrapper) Close() error { + return p.conn.Close() +} + +// LocalAddr returns the local address of the underlying net.Conn. +func (p *packetConnWrapper) LocalAddr() net.Addr { + return p.conn.LocalAddr() +} + +// SetDeadline sets the deadline on the underlying net.Conn. 
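+// It forwards directly to the wrapped net.Conn, so the deadline applies to
+// both reads and writes on the underlying connection.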
+func (p *packetConnWrapper) SetDeadline(t time.Time) error { + return p.conn.SetDeadline(t) +} + +// SetReadDeadline sets the read deadline on the underlying net.Conn. +func (p *packetConnWrapper) SetReadDeadline(t time.Time) error { + return p.conn.SetReadDeadline(t) +} + +// SetWriteDeadline sets the write deadline on the underlying net.Conn. +func (p *packetConnWrapper) SetWriteDeadline(t time.Time) error { + return p.conn.SetWriteDeadline(t) +} diff --git a/vendor/github.com/pion/dtls/v2/pkg/protocol/content.go b/vendor/github.com/pion/dtls/v2/pkg/protocol/content.go new file mode 100644 index 0000000000..154005e2c3 --- /dev/null +++ b/vendor/github.com/pion/dtls/v2/pkg/protocol/content.go @@ -0,0 +1,25 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package protocol + +// ContentType represents the IANA Registered ContentTypes +// +// https://tools.ietf.org/html/rfc4346#section-6.2.1 +type ContentType uint8 + +// ContentType enums +const ( + ContentTypeChangeCipherSpec ContentType = 20 + ContentTypeAlert ContentType = 21 + ContentTypeHandshake ContentType = 22 + ContentTypeApplicationData ContentType = 23 + ContentTypeConnectionID ContentType = 25 +) + +// Content is the top level distinguisher for a DTLS Datagram +type Content interface { + ContentType() ContentType + Marshal() ([]byte, error) + Unmarshal(data []byte) error +} diff --git a/vendor/github.com/pion/dtls/v2/pkg/protocol/extension/connection_id.go b/vendor/github.com/pion/dtls/v2/pkg/protocol/extension/connection_id.go new file mode 100644 index 0000000000..b3fe1640f2 --- /dev/null +++ b/vendor/github.com/pion/dtls/v2/pkg/protocol/extension/connection_id.go @@ -0,0 +1,59 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package extension + +import ( + "golang.org/x/crypto/cryptobyte" +) + +// ConnectionID is a DTLS extension that provides an alternative to IP address +// and port for session association. 
+// +// https://tools.ietf.org/html/rfc9146 +type ConnectionID struct { + // A zero-length connection ID indicates for a client or server that + // negotiated connection IDs from the peer will be sent but there is no need + // to respond with one + CID []byte // variable length +} + +// TypeValue returns the extension TypeValue +func (c ConnectionID) TypeValue() TypeValue { + return ConnectionIDTypeValue +} + +// Marshal encodes the extension +func (c *ConnectionID) Marshal() ([]byte, error) { + var b cryptobyte.Builder + b.AddUint16(uint16(c.TypeValue())) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(c.CID) + }) + }) + return b.Bytes() +} + +// Unmarshal populates the extension from encoded data +func (c *ConnectionID) Unmarshal(data []byte) error { + val := cryptobyte.String(data) + var extension uint16 + val.ReadUint16(&extension) + if TypeValue(extension) != c.TypeValue() { + return errInvalidExtensionType + } + + var extData cryptobyte.String + val.ReadUint16LengthPrefixed(&extData) + + var cid cryptobyte.String + if !extData.ReadUint8LengthPrefixed(&cid) { + return errInvalidCIDFormat + } + c.CID = make([]byte, len(cid)) + if !cid.CopyBytes(c.CID) { + return errInvalidCIDFormat + } + return nil +} diff --git a/vendor/github.com/pion/dtls/v2/pkg/protocol/extension/errors.go b/vendor/github.com/pion/dtls/v2/pkg/protocol/extension/errors.go new file mode 100644 index 0000000000..39431206f5 --- /dev/null +++ b/vendor/github.com/pion/dtls/v2/pkg/protocol/extension/errors.go @@ -0,0 +1,21 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package extension + +import ( + "errors" + + "github.com/pion/dtls/v2/pkg/protocol" +) + +var ( + // ErrALPNInvalidFormat is raised when the ALPN format is invalid + ErrALPNInvalidFormat = &protocol.FatalError{Err: errors.New("invalid alpn format")} //nolint:goerr113 + errALPNNoAppProto = &protocol.FatalError{Err: errors.New("no application protocol")} //nolint:goerr113 + errBufferTooSmall = &protocol.TemporaryError{Err: errors.New("buffer is too small")} //nolint:goerr113 + errInvalidExtensionType = &protocol.FatalError{Err: errors.New("invalid extension type")} //nolint:goerr113 + errInvalidSNIFormat = &protocol.FatalError{Err: errors.New("invalid server name format")} //nolint:goerr113 + errInvalidCIDFormat = &protocol.FatalError{Err: errors.New("invalid connection ID format")} //nolint:goerr113 + errLengthMismatch = &protocol.InternalError{Err: errors.New("data length and declared length do not match")} //nolint:goerr113 +) diff --git a/vendor/github.com/pion/dtls/v2/pkg/protocol/extension/extension.go b/vendor/github.com/pion/dtls/v2/pkg/protocol/extension/extension.go new file mode 100644 index 0000000000..e4df859f89 --- /dev/null +++ b/vendor/github.com/pion/dtls/v2/pkg/protocol/extension/extension.go @@ -0,0 +1,109 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package extension implements the extension values in the ClientHello/ServerHello +package extension + +import "encoding/binary" + +// TypeValue is the 2 byte value for a TLS Extension as registered in the IANA +// +// https://www.iana.org/assignments/tls-extensiontype-values/tls-extensiontype-values.xhtml +type TypeValue uint16 + +// TypeValue constants +const ( + ServerNameTypeValue TypeValue = 0 + SupportedEllipticCurvesTypeValue TypeValue = 10 + SupportedPointFormatsTypeValue TypeValue = 11 + 
SupportedSignatureAlgorithmsTypeValue TypeValue = 13 + UseSRTPTypeValue TypeValue = 14 + ALPNTypeValue TypeValue = 16 + UseExtendedMasterSecretTypeValue TypeValue = 23 + ConnectionIDTypeValue TypeValue = 54 + RenegotiationInfoTypeValue TypeValue = 65281 +) + +// Extension represents a single TLS extension +type Extension interface { + Marshal() ([]byte, error) + Unmarshal(data []byte) error + TypeValue() TypeValue +} + +// Unmarshal many extensions at once +func Unmarshal(buf []byte) ([]Extension, error) { + switch { + case len(buf) == 0: + return []Extension{}, nil + case len(buf) < 2: + return nil, errBufferTooSmall + } + + declaredLen := binary.BigEndian.Uint16(buf) + if len(buf)-2 != int(declaredLen) { + return nil, errLengthMismatch + } + + extensions := []Extension{} + unmarshalAndAppend := func(data []byte, e Extension) error { + err := e.Unmarshal(data) + if err != nil { + return err + } + extensions = append(extensions, e) + return nil + } + + for offset := 2; offset < len(buf); { + if len(buf) < (offset + 2) { + return nil, errBufferTooSmall + } + var err error + switch TypeValue(binary.BigEndian.Uint16(buf[offset:])) { + case ServerNameTypeValue: + err = unmarshalAndAppend(buf[offset:], &ServerName{}) + case SupportedEllipticCurvesTypeValue: + err = unmarshalAndAppend(buf[offset:], &SupportedEllipticCurves{}) + case SupportedPointFormatsTypeValue: + err = unmarshalAndAppend(buf[offset:], &SupportedPointFormats{}) + case SupportedSignatureAlgorithmsTypeValue: + err = unmarshalAndAppend(buf[offset:], &SupportedSignatureAlgorithms{}) + case UseSRTPTypeValue: + err = unmarshalAndAppend(buf[offset:], &UseSRTP{}) + case ALPNTypeValue: + err = unmarshalAndAppend(buf[offset:], &ALPN{}) + case UseExtendedMasterSecretTypeValue: + err = unmarshalAndAppend(buf[offset:], &UseExtendedMasterSecret{}) + case RenegotiationInfoTypeValue: + err = unmarshalAndAppend(buf[offset:], &RenegotiationInfo{}) + case ConnectionIDTypeValue: + err = unmarshalAndAppend(buf[offset:], &ConnectionID{}) + default: + } + if err != nil { + return nil, err + } + if len(buf) < (offset + 4) { + return nil, errBufferTooSmall + } + extensionLength := binary.BigEndian.Uint16(buf[offset+2:]) + offset += (4 + int(extensionLength)) + } + return extensions, nil +} + +// Marshal many extensions at once +func Marshal(e []Extension) ([]byte, error) { + extensions := []byte{} + for _, e := range e { + raw, err := e.Marshal() + if err != nil { + return nil, err + } + extensions = append(extensions, raw...) 
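+		// Each extension serializes its own type and length, so only the
+		// aggregate two-byte length prefix is prepended below.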
+ } + out := []byte{0x00, 0x00} + binary.BigEndian.PutUint16(out, uint16(len(extensions))) + return append(out, extensions...), nil +} diff --git a/vendor/github.com/pion/dtls/v2/pkg/protocol/extension/supported_point_formats.go b/vendor/github.com/pion/dtls/v2/pkg/protocol/extension/supported_point_formats.go new file mode 100644 index 0000000000..5ed0f347f2 --- /dev/null +++ b/vendor/github.com/pion/dtls/v2/pkg/protocol/extension/supported_point_formats.go @@ -0,0 +1,67 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package extension + +import ( + "encoding/binary" + + "github.com/pion/dtls/v2/pkg/crypto/elliptic" +) + +const ( + supportedPointFormatsSize = 5 +) + +// SupportedPointFormats allows a Client/Server to negotiate +// the EllipticCurvePointFormats +// +// https://tools.ietf.org/html/rfc4492#section-5.1.2 +type SupportedPointFormats struct { + PointFormats []elliptic.CurvePointFormat +} + +// TypeValue returns the extension TypeValue +func (s SupportedPointFormats) TypeValue() TypeValue { + return SupportedPointFormatsTypeValue +} + +// Marshal encodes the extension +func (s *SupportedPointFormats) Marshal() ([]byte, error) { + out := make([]byte, supportedPointFormatsSize) + + binary.BigEndian.PutUint16(out, uint16(s.TypeValue())) + binary.BigEndian.PutUint16(out[2:], uint16(1+(len(s.PointFormats)))) + out[4] = byte(len(s.PointFormats)) + + for _, v := range s.PointFormats { + out = append(out, byte(v)) + } + return out, nil +} + +// Unmarshal populates the extension from encoded data +func (s *SupportedPointFormats) Unmarshal(data []byte) error { + if len(data) <= supportedPointFormatsSize { + return errBufferTooSmall + } + + if TypeValue(binary.BigEndian.Uint16(data)) != s.TypeValue() { + return errInvalidExtensionType + } + + pointFormatCount := int(data[4]) + if supportedPointFormatsSize+pointFormatCount > len(data) { + return errLengthMismatch + } + + for i := 0; i < pointFormatCount; i++ { + p := elliptic.CurvePointFormat(data[supportedPointFormatsSize+i]) + switch p { + case elliptic.CurvePointFormatUncompressed: + s.PointFormats = append(s.PointFormats, p) + default: + } + } + return nil +} diff --git a/vendor/github.com/pion/dtls/v2/pkg/protocol/recordlayer/header.go b/vendor/github.com/pion/dtls/v2/pkg/protocol/recordlayer/header.go new file mode 100644 index 0000000000..66a1be810f --- /dev/null +++ b/vendor/github.com/pion/dtls/v2/pkg/protocol/recordlayer/header.go @@ -0,0 +1,85 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package recordlayer + +import ( + "encoding/binary" + + "github.com/pion/dtls/v2/internal/util" + "github.com/pion/dtls/v2/pkg/protocol" +) + +// Header implements a TLS RecordLayer header +type Header struct { + ContentType protocol.ContentType + ContentLen uint16 + Version protocol.Version + Epoch uint16 + SequenceNumber uint64 // uint48 in spec + + // Optional Fields + ConnectionID []byte +} + +// RecordLayer enums +const ( + // FixedHeaderSize is the size of a DTLS record header when connection IDs + // are not in use. 
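+	// (1 content type + 2 version + 2 epoch + 6 sequence number + 2 length.)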
+	FixedHeaderSize = 13
+	MaxSequenceNumber = 0x0000FFFFFFFFFFFF
+)
+
+// Marshal encodes a TLS RecordLayer Header to binary
+func (h *Header) Marshal() ([]byte, error) {
+	if h.SequenceNumber > MaxSequenceNumber {
+		return nil, errSequenceNumberOverflow
+	}
+
+	hs := FixedHeaderSize + len(h.ConnectionID)
+
+	out := make([]byte, hs)
+	out[0] = byte(h.ContentType)
+	out[1] = h.Version.Major
+	out[2] = h.Version.Minor
+	binary.BigEndian.PutUint16(out[3:], h.Epoch)
+	util.PutBigEndianUint48(out[5:], h.SequenceNumber)
+	copy(out[11:11+len(h.ConnectionID)], h.ConnectionID)
+	binary.BigEndian.PutUint16(out[hs-2:], h.ContentLen)
+	return out, nil
+}
+
+// Unmarshal populates a TLS RecordLayer Header from binary
+func (h *Header) Unmarshal(data []byte) error {
+	if len(data) < FixedHeaderSize {
+		return errBufferTooSmall
+	}
+	h.ContentType = protocol.ContentType(data[0])
+	if h.ContentType == protocol.ContentTypeConnectionID {
+		// If a CID was expected the ConnectionID should have been initialized.
+		if len(data) < FixedHeaderSize+len(h.ConnectionID) {
+			return errBufferTooSmall
+		}
+		h.ConnectionID = data[11 : 11+len(h.ConnectionID)]
+	}
+
+	h.Version.Major = data[1]
+	h.Version.Minor = data[2]
+	h.Epoch = binary.BigEndian.Uint16(data[3:])
+
+	// SequenceNumber is stored as uint48, make into uint64
+	seqCopy := make([]byte, 8)
+	copy(seqCopy[2:], data[5:11])
+	h.SequenceNumber = binary.BigEndian.Uint64(seqCopy)
+
+	if !h.Version.Equal(protocol.Version1_0) && !h.Version.Equal(protocol.Version1_2) {
+		return errUnsupportedProtocolVersion
+	}
+
+	return nil
+}
+
+// Size returns the total size of the header.
+func (h *Header) Size() int {
+	return FixedHeaderSize + len(h.ConnectionID)
+}
diff --git a/vendor/github.com/pion/dtls/v2/pkg/protocol/recordlayer/inner_plaintext.go b/vendor/github.com/pion/dtls/v2/pkg/protocol/recordlayer/inner_plaintext.go
new file mode 100644
index 0000000000..bbc94dd80d
--- /dev/null
+++ b/vendor/github.com/pion/dtls/v2/pkg/protocol/recordlayer/inner_plaintext.go
@@ -0,0 +1,47 @@
+// SPDX-FileCopyrightText: 2023 The Pion community
+// SPDX-License-Identifier: MIT
+
+package recordlayer
+
+import (
+	"github.com/pion/dtls/v2/pkg/protocol"
+	"golang.org/x/crypto/cryptobyte"
+)
+
+// InnerPlaintext implements DTLSInnerPlaintext
+//
+// https://datatracker.ietf.org/doc/html/rfc9146#name-record-layer-extensions
+type InnerPlaintext struct {
+	Content  []byte
+	RealType protocol.ContentType
+	Zeros    uint
+}
+
+// Marshal encodes a DTLS InnerPlaintext to binary
+func (p *InnerPlaintext) Marshal() ([]byte, error) {
+	var out cryptobyte.Builder
+	out.AddBytes(p.Content)
+	out.AddUint8(uint8(p.RealType))
+	out.AddBytes(make([]byte, p.Zeros))
+	return out.Bytes()
+}
+
+// Unmarshal populates a DTLS InnerPlaintext from binary
+func (p *InnerPlaintext) Unmarshal(data []byte) error {
+	// Process in reverse
+	i := len(data) - 1
+	for i >= 0 {
+		if data[i] != 0 {
+			p.Zeros = uint(len(data) - 1 - i)
+			break
+		}
+		i--
+	}
+	if i <= 0 { // i < 0 means all-zero padding; checking == 0 alone would index data[-1] below
+		return errBufferTooSmall
+	}
+	p.RealType = protocol.ContentType(data[i])
+	p.Content = append([]byte{}, data[:i]...)
+ + return nil +} diff --git a/vendor/github.com/pion/dtls/v2/pkg/protocol/recordlayer/recordlayer.go b/vendor/github.com/pion/dtls/v2/pkg/protocol/recordlayer/recordlayer.go new file mode 100644 index 0000000000..213a7976ad --- /dev/null +++ b/vendor/github.com/pion/dtls/v2/pkg/protocol/recordlayer/recordlayer.go @@ -0,0 +1,145 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package recordlayer + +import ( + "encoding/binary" + + "github.com/pion/dtls/v2/pkg/protocol" + "github.com/pion/dtls/v2/pkg/protocol/alert" + "github.com/pion/dtls/v2/pkg/protocol/handshake" +) + +// DTLS fixed size record layer header when Connection IDs are not in-use. + +// --------------------------------- +// | Type | Version | Epoch | +// --------------------------------- +// | Epoch | Sequence Number | +// --------------------------------- +// | Sequence Number | Length | +// --------------------------------- +// | Length | Fragment... | +// --------------------------------- + +// fixedHeaderLenIdx is the index at which the record layer content length is +// specified in a fixed length header (i.e. one that does not include a +// Connection ID). +const fixedHeaderLenIdx = 11 + +// RecordLayer which handles all data transport. +// The record layer is assumed to sit directly on top of some +// reliable transport such as TCP. The record layer can carry four types of content: +// +// 1. Handshake messages—used for algorithm negotiation and key establishment. +// 2. ChangeCipherSpec messages—really part of the handshake but technically a separate kind of message. +// 3. Alert messages—used to signal that errors have occurred +// 4. Application layer data +// +// The DTLS record layer is extremely similar to that of TLS 1.1. The +// only change is the inclusion of an explicit sequence number in the +// record. This sequence number allows the recipient to correctly +// verify the TLS MAC. +// +// https://tools.ietf.org/html/rfc4347#section-4.1 +type RecordLayer struct { + Header Header + Content protocol.Content +} + +// Marshal encodes the RecordLayer to binary +func (r *RecordLayer) Marshal() ([]byte, error) { + contentRaw, err := r.Content.Marshal() + if err != nil { + return nil, err + } + + r.Header.ContentLen = uint16(len(contentRaw)) + r.Header.ContentType = r.Content.ContentType() + + headerRaw, err := r.Header.Marshal() + if err != nil { + return nil, err + } + + return append(headerRaw, contentRaw...), nil +} + +// Unmarshal populates the RecordLayer from binary +func (r *RecordLayer) Unmarshal(data []byte) error { + if err := r.Header.Unmarshal(data); err != nil { + return err + } + + switch r.Header.ContentType { + case protocol.ContentTypeChangeCipherSpec: + r.Content = &protocol.ChangeCipherSpec{} + case protocol.ContentTypeAlert: + r.Content = &alert.Alert{} + case protocol.ContentTypeHandshake: + r.Content = &handshake.Handshake{} + case protocol.ContentTypeApplicationData: + r.Content = &protocol.ApplicationData{} + default: + return errInvalidContentType + } + + return r.Content.Unmarshal(data[r.Header.Size()+len(r.Header.ConnectionID):]) +} + +// UnpackDatagram extracts all RecordLayer messages from a single datagram. +// Note that as with TLS, multiple handshake messages may be placed in +// the same DTLS record, provided that there is room and that they are +// part of the same flight. Thus, there are two acceptable ways to pack +// two DTLS messages into the same datagram: in the same record or in +// separate records. 
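
For illustration, a minimal sketch (hand-written, not part of the patch; it assumes the vendored pion/dtls v2 protocol, alert, and recordlayer packages are imported) of the record layout handled here: each record is a fixed 13-byte header followed by its fragment, so several records can be concatenated into one datagram and recovered with UnpackDatagram.

rec := &recordlayer.RecordLayer{
	Header: recordlayer.Header{
		Version:        protocol.Version1_2,
		SequenceNumber: 1,
	},
	Content: &alert.Alert{Level: alert.Warning, Description: alert.CloseNotify},
}
raw, _ := rec.Marshal() // 13-byte header + 2-byte alert fragment; Marshal fills in length and type

datagram := append(append([]byte{}, raw...), raw...) // two records packed into one datagram
records, _ := recordlayer.UnpackDatagram(datagram)
// len(records) == 2, each slice a self-contained record
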
+// https://tools.ietf.org/html/rfc6347#section-4.2.3 +func UnpackDatagram(buf []byte) ([][]byte, error) { + out := [][]byte{} + + for offset := 0; len(buf) != offset; { + if len(buf)-offset <= FixedHeaderSize { + return nil, errInvalidPacketLength + } + + pktLen := (FixedHeaderSize + int(binary.BigEndian.Uint16(buf[offset+11:]))) + if offset+pktLen > len(buf) { + return nil, errInvalidPacketLength + } + + out = append(out, buf[offset:offset+pktLen]) + offset += pktLen + } + + return out, nil +} + +// ContentAwareUnpackDatagram is the same as UnpackDatagram but considers the +// presence of a connection identifier if the record is of content type +// tls12_cid. +func ContentAwareUnpackDatagram(buf []byte, cidLength int) ([][]byte, error) { + out := [][]byte{} + + for offset := 0; len(buf) != offset; { + headerSize := FixedHeaderSize + lenIdx := fixedHeaderLenIdx + if protocol.ContentType(buf[offset]) == protocol.ContentTypeConnectionID { + headerSize += cidLength + lenIdx += cidLength + } + if len(buf)-offset <= headerSize { + return nil, errInvalidPacketLength + } + + pktLen := (headerSize + int(binary.BigEndian.Uint16(buf[offset+lenIdx:]))) + if offset+pktLen > len(buf) { + return nil, errInvalidPacketLength + } + + out = append(out, buf[offset:offset+pktLen]) + offset += pktLen + } + + return out, nil +} diff --git a/vendor/github.com/pion/dtls/v2/resume.go b/vendor/github.com/pion/dtls/v2/resume.go new file mode 100644 index 0000000000..9e8a2ae42a --- /dev/null +++ b/vendor/github.com/pion/dtls/v2/resume.go @@ -0,0 +1,22 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "context" + "net" +) + +// Resume imports an already established dtls connection using a specific dtls state +func Resume(state *State, conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) { + if err := state.initCipherSuite(); err != nil { + return nil, err + } + c, err := createConn(context.Background(), conn, rAddr, config, state.isClient, state) + if err != nil { + return nil, err + } + + return c, nil +} diff --git a/vendor/github.com/pion/dtls/v2/state.go b/vendor/github.com/pion/dtls/v2/state.go new file mode 100644 index 0000000000..cd69455735 --- /dev/null +++ b/vendor/github.com/pion/dtls/v2/state.go @@ -0,0 +1,240 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package dtls + +import ( + "bytes" + "encoding/gob" + "sync/atomic" + + "github.com/pion/dtls/v2/pkg/crypto/elliptic" + "github.com/pion/dtls/v2/pkg/crypto/prf" + "github.com/pion/dtls/v2/pkg/protocol/handshake" + "github.com/pion/transport/v3/replaydetector" +) + +// State holds the dtls connection state and implements both encoding.BinaryMarshaler and encoding.BinaryUnmarshaler +type State struct { + localEpoch, remoteEpoch atomic.Value + localSequenceNumber []uint64 // uint48 + localRandom, remoteRandom handshake.Random + masterSecret []byte + cipherSuite CipherSuite // nil if a cipherSuite hasn't been chosen + + srtpProtectionProfile SRTPProtectionProfile // Negotiated SRTPProtectionProfile + PeerCertificates [][]byte + IdentityHint []byte + SessionID []byte + + // Connection Identifiers must be negotiated afresh on session resumption. + // https://datatracker.ietf.org/doc/html/rfc9146#name-the-connection_id-extension + + // localConnectionID is the locally generated connection ID that is expected + // to be received from the remote endpoint. + // For a server, this is the connection ID sent in ServerHello. 
+ // For a client, this is the connection ID sent in the ClientHello. + localConnectionID []byte + // remoteConnectionID is the connection ID that the remote endpoint + // specifies should be sent. + // For a server, this is the connection ID received in the ClientHello. + // For a client, this is the connection ID received in the ServerHello. + remoteConnectionID []byte + + isClient bool + + preMasterSecret []byte + extendedMasterSecret bool + + namedCurve elliptic.Curve + localKeypair *elliptic.Keypair + cookie []byte + handshakeSendSequence int + handshakeRecvSequence int + serverName string + remoteRequestedCertificate bool // Did we get a CertificateRequest + localCertificatesVerify []byte // cache CertificateVerify + localVerifyData []byte // cached VerifyData + localKeySignature []byte // cached keySignature + peerCertificatesVerified bool + + replayDetector []replaydetector.ReplayDetector + + peerSupportedProtocols []string + NegotiatedProtocol string +} + +type serializedState struct { + LocalEpoch uint16 + RemoteEpoch uint16 + LocalRandom [handshake.RandomLength]byte + RemoteRandom [handshake.RandomLength]byte + CipherSuiteID uint16 + MasterSecret []byte + SequenceNumber uint64 + SRTPProtectionProfile uint16 + PeerCertificates [][]byte + IdentityHint []byte + SessionID []byte + LocalConnectionID []byte + RemoteConnectionID []byte + IsClient bool +} + +func (s *State) clone() *State { + serialized := s.serialize() + state := &State{} + state.deserialize(*serialized) + + return state +} + +func (s *State) serialize() *serializedState { + // Marshal random values + localRnd := s.localRandom.MarshalFixed() + remoteRnd := s.remoteRandom.MarshalFixed() + + epoch := s.getLocalEpoch() + return &serializedState{ + LocalEpoch: s.getLocalEpoch(), + RemoteEpoch: s.getRemoteEpoch(), + CipherSuiteID: uint16(s.cipherSuite.ID()), + MasterSecret: s.masterSecret, + SequenceNumber: atomic.LoadUint64(&s.localSequenceNumber[epoch]), + LocalRandom: localRnd, + RemoteRandom: remoteRnd, + SRTPProtectionProfile: uint16(s.srtpProtectionProfile), + PeerCertificates: s.PeerCertificates, + IdentityHint: s.IdentityHint, + SessionID: s.SessionID, + LocalConnectionID: s.localConnectionID, + RemoteConnectionID: s.remoteConnectionID, + IsClient: s.isClient, + } +} + +func (s *State) deserialize(serialized serializedState) { + // Set epoch values + epoch := serialized.LocalEpoch + s.localEpoch.Store(serialized.LocalEpoch) + s.remoteEpoch.Store(serialized.RemoteEpoch) + + for len(s.localSequenceNumber) <= int(epoch) { + s.localSequenceNumber = append(s.localSequenceNumber, uint64(0)) + } + + // Set random values + localRandom := &handshake.Random{} + localRandom.UnmarshalFixed(serialized.LocalRandom) + s.localRandom = *localRandom + + remoteRandom := &handshake.Random{} + remoteRandom.UnmarshalFixed(serialized.RemoteRandom) + s.remoteRandom = *remoteRandom + + s.isClient = serialized.IsClient + + // Set master secret + s.masterSecret = serialized.MasterSecret + + // Set cipher suite + s.cipherSuite = cipherSuiteForID(CipherSuiteID(serialized.CipherSuiteID), nil) + + atomic.StoreUint64(&s.localSequenceNumber[epoch], serialized.SequenceNumber) + s.srtpProtectionProfile = SRTPProtectionProfile(serialized.SRTPProtectionProfile) + + // Set remote certificate + s.PeerCertificates = serialized.PeerCertificates + + s.IdentityHint = serialized.IdentityHint + + // Set local and remote connection IDs + s.localConnectionID = serialized.LocalConnectionID + s.remoteConnectionID = serialized.RemoteConnectionID + + s.SessionID = 
serialized.SessionID +} + +func (s *State) initCipherSuite() error { + if s.cipherSuite.IsInitialized() { + return nil + } + + localRandom := s.localRandom.MarshalFixed() + remoteRandom := s.remoteRandom.MarshalFixed() + + var err error + if s.isClient { + err = s.cipherSuite.Init(s.masterSecret, localRandom[:], remoteRandom[:], true) + } else { + err = s.cipherSuite.Init(s.masterSecret, remoteRandom[:], localRandom[:], false) + } + if err != nil { + return err + } + return nil +} + +// MarshalBinary is a binary.BinaryMarshaler.MarshalBinary implementation +func (s *State) MarshalBinary() ([]byte, error) { + serialized := s.serialize() + + var buf bytes.Buffer + enc := gob.NewEncoder(&buf) + if err := enc.Encode(*serialized); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +// UnmarshalBinary is a binary.BinaryUnmarshaler.UnmarshalBinary implementation +func (s *State) UnmarshalBinary(data []byte) error { + enc := gob.NewDecoder(bytes.NewBuffer(data)) + var serialized serializedState + if err := enc.Decode(&serialized); err != nil { + return err + } + + s.deserialize(serialized) + + return s.initCipherSuite() +} + +// ExportKeyingMaterial returns length bytes of exported key material in a new +// slice as defined in RFC 5705. +// This allows protocols to use DTLS for key establishment, but +// then use some of the keying material for their own purposes +func (s *State) ExportKeyingMaterial(label string, context []byte, length int) ([]byte, error) { + if s.getLocalEpoch() == 0 { + return nil, errHandshakeInProgress + } else if len(context) != 0 { + return nil, errContextUnsupported + } else if _, ok := invalidKeyingLabels()[label]; ok { + return nil, errReservedExportKeyingMaterial + } + + localRandom := s.localRandom.MarshalFixed() + remoteRandom := s.remoteRandom.MarshalFixed() + + seed := []byte(label) + if s.isClient { + seed = append(append(seed, localRandom[:]...), remoteRandom[:]...) + } else { + seed = append(append(seed, remoteRandom[:]...), localRandom[:]...) + } + return prf.PHash(s.masterSecret, seed, length, s.cipherSuite.HashFunc()) +} + +func (s *State) getRemoteEpoch() uint16 { + if remoteEpoch, ok := s.remoteEpoch.Load().(uint16); ok { + return remoteEpoch + } + return 0 +} + +func (s *State) getLocalEpoch() uint16 { + if localEpoch, ok := s.localEpoch.Load().(uint16); ok { + return localEpoch + } + return 0 +} diff --git a/vendor/github.com/pion/transport/v3/AUTHORS.txt b/vendor/github.com/pion/transport/v3/AUTHORS.txt new file mode 100644 index 0000000000..278ce646f7 --- /dev/null +++ b/vendor/github.com/pion/transport/v3/AUTHORS.txt @@ -0,0 +1,31 @@ +# Thank you to everyone that made Pion possible. If you are interested in contributing +# we would love to have you https://github.com/pion/webrtc/wiki/Contributing +# +# This file is auto generated, using git to list all individuals contributors. 
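
The serialize/deserialize pair above is what powers session resumption; a rough sketch (hand-written, not from the patch; conn is an established *dtls.Conn, and pc/raddr are the packet conn and peer address to resume on) of exporting a connection's state and reviving it with Resume:

state := conn.ConnectionState()    // pion/dtls v2 snapshot of the handshake state
blob, err := state.MarshalBinary() // gob-encoded serializedState
if err != nil {
	return err
}

restored := &dtls.State{}
if err := restored.UnmarshalBinary(blob); err != nil { // also re-initializes the cipher suite
	return err
}
resumed, err := dtls.Resume(restored, pc, raddr, &dtls.Config{})
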
+# see https://github.com/pion/.goassets/blob/master/scripts/generate-authors.sh for the scripting +Adrian Cable +Atsushi Watanabe +backkem +cnderrauber +Daniel +Daniel Mangum +Hugo Arregui +Jeremiah Millay +Jozef Kralik +Juliusz Chroboczek +Luke Curley +Mathis Engelbart +OrlandoCo +Sean DuBois +Sean DuBois +Sean DuBois +Sean DuBois +Sean DuBois +Steffen Vogel +Winlin +Woodrow Douglass +Yutaka Takeda +ZHENK + +# List of contributors not appearing in Git history + diff --git a/vendor/github.com/pion/transport/v3/LICENSE b/vendor/github.com/pion/transport/v3/LICENSE new file mode 100644 index 0000000000..491caf6b0f --- /dev/null +++ b/vendor/github.com/pion/transport/v3/LICENSE @@ -0,0 +1,9 @@ +MIT License + +Copyright (c) 2023 The Pion community + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/pion/transport/v3/deadline/deadline.go b/vendor/github.com/pion/transport/v3/deadline/deadline.go new file mode 100644 index 0000000000..abd39f06de --- /dev/null +++ b/vendor/github.com/pion/transport/v3/deadline/deadline.go @@ -0,0 +1,117 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package deadline provides deadline timer used to implement +// net.Conn compatible connection +package deadline + +import ( + "context" + "sync" + "time" +) + +// Deadline signals updatable deadline timer. +// Also, it implements context.Context. +type Deadline struct { + exceeded chan struct{} + stop chan struct{} + stopped chan bool + deadline time.Time + mu sync.RWMutex +} + +// New creates new deadline timer. +func New() *Deadline { + d := &Deadline{ + exceeded: make(chan struct{}), + stop: make(chan struct{}), + stopped: make(chan bool, 1), + } + d.stopped <- true + return d +} + +// Set new deadline. Zero value means no deadline. 
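
A small usage sketch (hand-written, not part of the patch) may help here: Deadline doubles as a context.Context whose deadline can be re-armed at any time, which is how the packetio and udp packages later in this patch implement SetReadDeadline.

d := deadline.New()
d.Set(time.Now().Add(50 * time.Millisecond)) // arm the timer

<-d.Done()     // closes like a context when the deadline passes
err := d.Err() // context.DeadlineExceeded

d.Set(time.Time{}) // zero time disarms it again
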
+func (d *Deadline) Set(t time.Time) { + d.mu.Lock() + defer d.mu.Unlock() + + d.deadline = t + + close(d.stop) + + select { + case <-d.exceeded: + d.exceeded = make(chan struct{}) + default: + stopped := <-d.stopped + if !stopped { + d.exceeded = make(chan struct{}) + } + } + d.stop = make(chan struct{}) + d.stopped = make(chan bool, 1) + + if t.IsZero() { + d.stopped <- true + return + } + + if dur := time.Until(t); dur > 0 { + exceeded := d.exceeded + stopped := d.stopped + go func() { + timer := time.NewTimer(dur) + select { + case <-timer.C: + close(exceeded) + stopped <- false + case <-d.stop: + if !timer.Stop() { + <-timer.C + } + stopped <- true + } + }() + return + } + + close(d.exceeded) + d.stopped <- false +} + +// Done receives deadline signal. +func (d *Deadline) Done() <-chan struct{} { + d.mu.RLock() + defer d.mu.RUnlock() + return d.exceeded +} + +// Err returns context.DeadlineExceeded if the deadline is exceeded. +// Otherwise, it returns nil. +func (d *Deadline) Err() error { + d.mu.RLock() + defer d.mu.RUnlock() + select { + case <-d.exceeded: + return context.DeadlineExceeded + default: + return nil + } +} + +// Deadline returns current deadline. +func (d *Deadline) Deadline() (time.Time, bool) { + d.mu.RLock() + defer d.mu.RUnlock() + if d.deadline.IsZero() { + return d.deadline, false + } + return d.deadline, true +} + +// Value returns nil. +func (d *Deadline) Value(interface{}) interface{} { + return nil +} diff --git a/vendor/github.com/pion/transport/v3/netctx/conn.go b/vendor/github.com/pion/transport/v3/netctx/conn.go new file mode 100644 index 0000000000..823107c7e0 --- /dev/null +++ b/vendor/github.com/pion/transport/v3/netctx/conn.go @@ -0,0 +1,185 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package netctx wraps common net interfaces using context.Context. +package netctx + +import ( + "context" + "errors" + "io" + "net" + "sync" + "sync/atomic" + "time" +) + +// ErrClosing is returned on Write to closed connection. +var ErrClosing = errors.New("use of closed network connection") + +// Reader is an interface for context controlled reader. +type Reader interface { + ReadContext(context.Context, []byte) (int, error) +} + +// Writer is an interface for context controlled writer. +type Writer interface { + WriteContext(context.Context, []byte) (int, error) +} + +// ReadWriter is a composite of ReadWriter. +type ReadWriter interface { + Reader + Writer +} + +// Conn is a wrapper of net.Conn using context.Context. +type Conn interface { + Reader + Writer + io.Closer + LocalAddr() net.Addr + RemoteAddr() net.Addr + Conn() net.Conn +} + +type conn struct { + nextConn net.Conn + closed chan struct{} + closeOnce sync.Once + readMu sync.Mutex + writeMu sync.Mutex +} + +var veryOld = time.Unix(0, 1) //nolint:gochecknoglobals + +// NewConn creates a new Conn wrapping given net.Conn. +func NewConn(netConn net.Conn) Conn { + c := &conn{ + nextConn: netConn, + closed: make(chan struct{}), + } + return c +} + +// ReadContext reads data from the connection. +// Unlike net.Conn.Read(), the provided context is used to control timeout. 
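
The trick used below is worth spelling out: when the context fires, the blocked Read is interrupted by setting a read deadline in the distant past (veryOld), and the deadline is cleared once the call returns. A usage sketch (hand-written; c is any established net.Conn):

ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()

nc := netctx.NewConn(c)
buf := make([]byte, 1500)
n, err := nc.ReadContext(ctx, buf) // err is ctx.Err() if the timeout fires before data arrives

The same deadline-interrupt pattern is reused for PacketConn's ReadFromContext and WriteToContext further down in this patch.
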
+func (c *conn) ReadContext(ctx context.Context, b []byte) (int, error) {
+	c.readMu.Lock()
+	defer c.readMu.Unlock()
+
+	select {
+	case <-c.closed:
+		return 0, net.ErrClosed
+	default:
+	}
+
+	done := make(chan struct{})
+	var wg sync.WaitGroup
+	var errSetDeadline atomic.Value
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		select {
+		case <-ctx.Done():
+			// context canceled
+			if err := c.nextConn.SetReadDeadline(veryOld); err != nil {
+				errSetDeadline.Store(err)
+				return
+			}
+			<-done
+			if err := c.nextConn.SetReadDeadline(time.Time{}); err != nil {
+				errSetDeadline.Store(err)
+			}
+		case <-done:
+		}
+	}()
+
+	n, err := c.nextConn.Read(b)
+
+	close(done)
+	wg.Wait()
+	if e := ctx.Err(); e != nil && n == 0 {
+		err = e
+	}
+	if err2, ok := errSetDeadline.Load().(error); ok && err == nil && err2 != nil {
+		err = err2
+	}
+	return n, err
+}
+
+// WriteContext writes data to the connection.
+// Unlike net.Conn.Write(), the provided context is used to control timeout.
+func (c *conn) WriteContext(ctx context.Context, b []byte) (int, error) {
+	c.writeMu.Lock()
+	defer c.writeMu.Unlock()
+
+	select {
+	case <-c.closed:
+		return 0, ErrClosing
+	default:
+	}
+
+	done := make(chan struct{})
+	var wg sync.WaitGroup
+	var errSetDeadline atomic.Value
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		select {
+		case <-ctx.Done():
+			// context canceled
+			if err := c.nextConn.SetWriteDeadline(veryOld); err != nil {
+				errSetDeadline.Store(err)
+				return
+			}
+			<-done
+			if err := c.nextConn.SetWriteDeadline(time.Time{}); err != nil {
+				errSetDeadline.Store(err)
+			}
+		case <-done:
+		}
+	}()
+
+	n, err := c.nextConn.Write(b)
+
+	close(done)
+	wg.Wait()
+	if e := ctx.Err(); e != nil && n == 0 {
+		err = e
+	}
+	if err2, ok := errSetDeadline.Load().(error); ok && err == nil && err2 != nil {
+		err = err2
+	}
+	return n, err
+}
+
+// Close closes the connection.
+// Any blocked ReadContext or WriteContext operations will be unblocked and
+// return errors.
+func (c *conn) Close() error {
+	err := c.nextConn.Close()
+	c.closeOnce.Do(func() {
+		c.writeMu.Lock()
+		c.readMu.Lock()
+		close(c.closed)
+		c.readMu.Unlock()
+		c.writeMu.Unlock()
+	})
+	return err
+}
+
+// LocalAddr returns the local network address, if known.
+func (c *conn) LocalAddr() net.Addr {
+	return c.nextConn.LocalAddr()
+}
+
+// RemoteAddr returns the remote network address, if known.
+func (c *conn) RemoteAddr() net.Addr {
+	return c.nextConn.RemoteAddr()
+}
+
+// Conn returns the underlying net.Conn.
+func (c *conn) Conn() net.Conn {
+	return c.nextConn
+}
diff --git a/vendor/github.com/pion/transport/v3/netctx/packetconn.go b/vendor/github.com/pion/transport/v3/netctx/packetconn.go
new file mode 100644
index 0000000000..a4ce22d956
--- /dev/null
+++ b/vendor/github.com/pion/transport/v3/netctx/packetconn.go
@@ -0,0 +1,175 @@
+// SPDX-FileCopyrightText: 2023 The Pion community
+// SPDX-License-Identifier: MIT
+
+package netctx
+
+import (
+	"context"
+	"io"
+	"net"
+	"sync"
+	"sync/atomic"
+	"time"
+)
+
+// ReaderFrom is an interface for context controlled packet reader.
+type ReaderFrom interface {
+	ReadFromContext(context.Context, []byte) (int, net.Addr, error)
+}
+
+// WriterTo is an interface for context controlled packet writer.
+type WriterTo interface {
+	WriteToContext(context.Context, []byte, net.Addr) (int, error)
+}
+
+// PacketConn is a wrapper of net.PacketConn using context.Context.
+type PacketConn interface { + ReaderFrom + WriterTo + io.Closer + LocalAddr() net.Addr + Conn() net.PacketConn +} + +type packetConn struct { + nextConn net.PacketConn + closed chan struct{} + closeOnce sync.Once + readMu sync.Mutex + writeMu sync.Mutex +} + +// NewPacketConn creates a new PacketConn wrapping the given net.PacketConn. +func NewPacketConn(pconn net.PacketConn) PacketConn { + p := &packetConn{ + nextConn: pconn, + closed: make(chan struct{}), + } + return p +} + +// ReadFromContext reads a packet from the connection, +// copying the payload into p. It returns the number of +// bytes copied into p and the return address that +// was on the packet. +// It returns the number of bytes read (0 <= n <= len(p)) +// and any error encountered. Callers should always process +// the n > 0 bytes returned before considering the error err. +// Unlike net.PacketConn.ReadFrom(), the provided context is +// used to control timeout. +func (p *packetConn) ReadFromContext(ctx context.Context, b []byte) (int, net.Addr, error) { + p.readMu.Lock() + defer p.readMu.Unlock() + + select { + case <-p.closed: + return 0, nil, net.ErrClosed + default: + } + + done := make(chan struct{}) + var wg sync.WaitGroup + var errSetDeadline atomic.Value + wg.Add(1) + go func() { + defer wg.Done() + select { + case <-ctx.Done(): + // context canceled + if err := p.nextConn.SetReadDeadline(veryOld); err != nil { + errSetDeadline.Store(err) + return + } + <-done + if err := p.nextConn.SetReadDeadline(time.Time{}); err != nil { + errSetDeadline.Store(err) + } + case <-done: + } + }() + + n, raddr, err := p.nextConn.ReadFrom(b) + + close(done) + wg.Wait() + if e := ctx.Err(); e != nil && n == 0 { + err = e + } + if err2, ok := errSetDeadline.Load().(error); ok && err == nil && err2 != nil { + err = err2 + } + return n, raddr, err +} + +// WriteToContext writes a packet with payload p to addr. +// Unlike net.PacketConn.WriteTo(), the provided context +// is used to control timeout. +// On packet-oriented connections, write timeouts are rare. +func (p *packetConn) WriteToContext(ctx context.Context, b []byte, raddr net.Addr) (int, error) { + p.writeMu.Lock() + defer p.writeMu.Unlock() + + select { + case <-p.closed: + return 0, ErrClosing + default: + } + + done := make(chan struct{}) + var wg sync.WaitGroup + var errSetDeadline atomic.Value + wg.Add(1) + go func() { + defer wg.Done() + select { + case <-ctx.Done(): + // context canceled + if err := p.nextConn.SetWriteDeadline(veryOld); err != nil { + errSetDeadline.Store(err) + return + } + <-done + if err := p.nextConn.SetWriteDeadline(time.Time{}); err != nil { + errSetDeadline.Store(err) + } + case <-done: + } + }() + + n, err := p.nextConn.WriteTo(b, raddr) + + close(done) + wg.Wait() + if e := ctx.Err(); e != nil && n == 0 { + err = e + } + if err2, ok := errSetDeadline.Load().(error); ok && err == nil && err2 != nil { + err = err2 + } + return n, err +} + +// Close closes the connection. +// Any blocked ReadFromContext or WriteToContext operations will be unblocked +// and return errors. +func (p *packetConn) Close() error { + err := p.nextConn.Close() + p.closeOnce.Do(func() { + p.writeMu.Lock() + p.readMu.Lock() + close(p.closed) + p.readMu.Unlock() + p.writeMu.Unlock() + }) + return err +} + +// LocalAddr returns the local network address, if known. +func (p *packetConn) LocalAddr() net.Addr { + return p.nextConn.LocalAddr() +} + +// Conn returns the underlying net.PacketConn. 
+func (p *packetConn) Conn() net.PacketConn {
+	return p.nextConn
+}
diff --git a/vendor/github.com/pion/transport/v3/netctx/pipe.go b/vendor/github.com/pion/transport/v3/netctx/pipe.go
new file mode 100644
index 0000000000..7deae668aa
--- /dev/null
+++ b/vendor/github.com/pion/transport/v3/netctx/pipe.go
@@ -0,0 +1,14 @@
+// SPDX-FileCopyrightText: 2023 The Pion community
+// SPDX-License-Identifier: MIT
+
+package netctx
+
+import (
+	"net"
+)
+
+// Pipe creates piped pair of Conn.
+func Pipe() (Conn, Conn) {
+	ca, cb := net.Pipe()
+	return NewConn(ca), NewConn(cb)
+}
diff --git a/vendor/github.com/pion/transport/v3/packetio/buffer.go b/vendor/github.com/pion/transport/v3/packetio/buffer.go
new file mode 100644
index 0000000000..76c72734d6
--- /dev/null
+++ b/vendor/github.com/pion/transport/v3/packetio/buffer.go
@@ -0,0 +1,351 @@
+// SPDX-FileCopyrightText: 2023 The Pion community
+// SPDX-License-Identifier: MIT
+
+// Package packetio provides packet buffer
+package packetio
+
+import (
+	"errors"
+	"io"
+	"sync"
+	"time"
+
+	"github.com/pion/transport/v3/deadline"
+)
+
+var errPacketTooBig = errors.New("packet too big")
+
+// BufferPacketType allow the Buffer to know which packet protocol is writing.
+type BufferPacketType int
+
+const (
+	// RTPBufferPacket indicates the Buffer that is handling RTP packets
+	RTPBufferPacket BufferPacketType = 1
+	// RTCPBufferPacket indicates the Buffer that is handling RTCP packets
+	RTCPBufferPacket BufferPacketType = 2
+)
+
+// Buffer allows writing packets to an intermediate buffer, which can then be read from.
+// This is very similar to bytes.Buffer but avoids combining multiple writes into a single read.
+type Buffer struct {
+	mutex sync.Mutex
+
+	// this is a circular buffer. If head <= tail, then the useful
+	// data is in the interval [head, tail[. If tail < head, then
+	// the useful data is the union of [head, len[ and [0, tail[.
+	// In order to avoid ambiguity when head = tail, we always leave
+	// an unused byte in the buffer.
+	data       []byte
+	head, tail int
+
+	notify chan struct{} // non-nil when we have blocked readers
+	closed bool
+
+	count                 int
+	limitCount, limitSize int
+
+	readDeadline *deadline.Deadline
+}
+
+const (
+	minSize    = 2048
+	cutoffSize = 128 * 1024
+	maxSize    = 4 * 1024 * 1024
+)
+
+// NewBuffer creates a new Buffer.
+func NewBuffer() *Buffer {
+	return &Buffer{
+		readDeadline: deadline.New(),
+	}
+}
+
+// available returns true if the buffer is large enough to fit a packet
+// of the given size, taking overhead into account.
+func (b *Buffer) available(size int) bool {
+	available := b.head - b.tail
+	if available <= 0 {
+		available += len(b.data)
+	}
+	// we interpret head=tail as empty, so always keep a byte free
+	if size+2+1 > available {
+		return false
+	}
+
+	return true
+}
+
+// grow increases the size of the buffer. If it returns nil, then the
+// buffer has been grown. It returns ErrFull if it hits a limit.
+func (b *Buffer) grow() error {
+	var newSize int
+	if len(b.data) < cutoffSize {
+		newSize = 2 * len(b.data)
+	} else {
+		newSize = 5 * len(b.data) / 4
+	}
+	if newSize < minSize {
+		newSize = minSize
+	}
+	if (b.limitSize <= 0 || sizeHardLimit) && newSize > maxSize {
+		newSize = maxSize
+	}
+
+	// one byte slack
+	if b.limitSize > 0 && newSize > b.limitSize+1 {
+		newSize = b.limitSize + 1
+	}
+
+	if newSize <= len(b.data) {
+		return ErrFull
+	}
+
+	newData := make([]byte, newSize)
+
+	var n int
+	if b.head <= b.tail {
+		// data was contiguous
+		n = copy(newData, b.data[b.head:b.tail])
+	} else {
+		// data was discontinuous
+		n = copy(newData, b.data[b.head:])
+		n += copy(newData[n:], b.data[:b.tail])
+	}
+	b.head = 0
+	b.tail = n
+	b.data = newData
+
+	return nil
+}
+
+// Write appends a copy of the packet data to the buffer.
+// Returns ErrFull if the packet doesn't fit.
+//
+// Note that the packet size is limited to 65536 bytes since v0.11.0 due to the internal data structure.
+func (b *Buffer) Write(packet []byte) (int, error) {
+	if len(packet) >= 0x10000 {
+		return 0, errPacketTooBig
+	}
+
+	b.mutex.Lock()
+
+	if b.closed {
+		b.mutex.Unlock()
+		return 0, io.ErrClosedPipe
+	}
+
+	if (b.limitCount > 0 && b.count >= b.limitCount) ||
+		(b.limitSize > 0 && b.size()+2+len(packet) > b.limitSize) {
+		b.mutex.Unlock()
+		return 0, ErrFull
+	}
+
+	// grow the buffer until the packet fits
+	for !b.available(len(packet)) {
+		err := b.grow()
+		if err != nil {
+			b.mutex.Unlock()
+			return 0, err
+		}
+	}
+
+	var notify chan struct{}
+	if b.notify != nil {
+		// Prepare to notify readers, but only
+		// actually do it after we release the lock.
+		notify = b.notify
+		b.notify = nil
+	}
+
+	// store the length of the packet
+	b.data[b.tail] = uint8(len(packet) >> 8)
+	b.tail++
+	if b.tail >= len(b.data) {
+		b.tail = 0
+	}
+	b.data[b.tail] = uint8(len(packet))
+	b.tail++
+	if b.tail >= len(b.data) {
+		b.tail = 0
+	}
+
+	// store the packet
+	n := copy(b.data[b.tail:], packet)
+	b.tail += n
+	if b.tail >= len(b.data) {
+		// we reached the end, wrap around
+		m := copy(b.data, packet[n:])
+		b.tail = m
+	}
+	b.count++
+	b.mutex.Unlock()
+
+	if notify != nil {
+		close(notify)
+	}
+
+	return len(packet), nil
+}
+
+// Read populates the given byte slice, returning the number of bytes read.
+// Blocks until data is available or the buffer is closed.
+// Returns io.ErrShortBuffer if the packet is too small to copy the Write.
+// Returns io.EOF if the buffer is closed.
+func (b *Buffer) Read(packet []byte) (n int, err error) { //nolint:gocognit
+	// Return immediately if the deadline is already exceeded.
+ select { + case <-b.readDeadline.Done(): + return 0, &netError{ErrTimeout, true, true} + default: + } + + for { + b.mutex.Lock() + + if b.head != b.tail { + // decode the packet size + n1 := b.data[b.head] + b.head++ + if b.head >= len(b.data) { + b.head = 0 + } + n2 := b.data[b.head] + b.head++ + if b.head >= len(b.data) { + b.head = 0 + } + count := int((uint16(n1) << 8) | uint16(n2)) + + // determine the number of bytes we'll actually copy + copied := count + if copied > len(packet) { + copied = len(packet) + } + + // copy the data + if b.head+copied < len(b.data) { + copy(packet, b.data[b.head:b.head+copied]) + } else { + k := copy(packet, b.data[b.head:]) + copy(packet[k:], b.data[:copied-k]) + } + + // advance head, discarding any data that wasn't copied + b.head += count + if b.head >= len(b.data) { + b.head -= len(b.data) + } + + if b.head == b.tail { + // the buffer is empty, reset to beginning + // in order to improve cache locality. + b.head = 0 + b.tail = 0 + } + + b.count-- + + b.mutex.Unlock() + + if copied < count { + return copied, io.ErrShortBuffer + } + return copied, nil + } + + if b.closed { + b.mutex.Unlock() + return 0, io.EOF + } + + if b.notify == nil { + b.notify = make(chan struct{}) + } + notify := b.notify + b.mutex.Unlock() + + select { + case <-b.readDeadline.Done(): + return 0, &netError{ErrTimeout, true, true} + case <-notify: + } + } +} + +// Close the buffer, unblocking any pending reads. +// Data in the buffer can still be read, Read will return io.EOF only when empty. +func (b *Buffer) Close() (err error) { + b.mutex.Lock() + + if b.closed { + b.mutex.Unlock() + return nil + } + + notify := b.notify + b.notify = nil + b.closed = true + + b.mutex.Unlock() + + if notify != nil { + close(notify) + } + + return nil +} + +// Count returns the number of packets in the buffer. +func (b *Buffer) Count() int { + b.mutex.Lock() + defer b.mutex.Unlock() + return b.count +} + +// SetLimitCount controls the maximum number of packets that can be buffered. +// Causes Write to return ErrFull when this limit is reached. +// A zero value will disable this limit. +func (b *Buffer) SetLimitCount(limit int) { + b.mutex.Lock() + defer b.mutex.Unlock() + + b.limitCount = limit +} + +// Size returns the total byte size of packets in the buffer, including +// a small amount of administrative overhead. +func (b *Buffer) Size() int { + b.mutex.Lock() + defer b.mutex.Unlock() + + return b.size() +} + +func (b *Buffer) size() int { + size := b.tail - b.head + if size < 0 { + size += len(b.data) + } + return size +} + +// SetLimitSize controls the maximum number of bytes that can be buffered. +// Causes Write to return ErrFull when this limit is reached. +// A zero value means 4MB since v0.11.0. +// +// User can set packetioSizeHardLimit build tag to enable 4MB hard limit. +// When packetioSizeHardLimit build tag is set, SetLimitSize exceeding +// the hard limit will be silently discarded. +func (b *Buffer) SetLimitSize(limit int) { + b.mutex.Lock() + defer b.mutex.Unlock() + + b.limitSize = limit +} + +// SetReadDeadline sets the deadline for the Read operation. +// Setting to zero means no deadline. 
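
A round-trip sketch (hand-written, not part of the patch) of the buffer semantics documented above: packet boundaries are preserved, and a read deadline turns a blocking Read into a timeout error.

b := packetio.NewBuffer()
_, _ = b.Write([]byte("hello")) // one Write == one packet, length-prefixed internally

pkt := make([]byte, 1500)
n, _ := b.Read(pkt) // returns exactly the "hello" packet, n == 5

_ = b.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
_, err := b.Read(pkt) // buffer now empty: err is a net.Error with Timeout() == true
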
+func (b *Buffer) SetReadDeadline(t time.Time) error { + b.readDeadline.Set(t) + return nil +} diff --git a/vendor/github.com/pion/transport/v3/packetio/errors.go b/vendor/github.com/pion/transport/v3/packetio/errors.go new file mode 100644 index 0000000000..4974a10b5c --- /dev/null +++ b/vendor/github.com/pion/transport/v3/packetio/errors.go @@ -0,0 +1,30 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package packetio + +import ( + "errors" +) + +// netError implements net.Error +type netError struct { + error + timeout, temporary bool +} + +func (e *netError) Timeout() bool { + return e.timeout +} + +func (e *netError) Temporary() bool { + return e.temporary +} + +var ( + // ErrFull is returned when the buffer has hit the configured limits. + ErrFull = errors.New("packetio.Buffer is full, discarding write") + + // ErrTimeout is returned when a deadline has expired + ErrTimeout = errors.New("i/o timeout") +) diff --git a/vendor/github.com/pion/transport/v3/packetio/hardlimit.go b/vendor/github.com/pion/transport/v3/packetio/hardlimit.go new file mode 100644 index 0000000000..8058e47fa9 --- /dev/null +++ b/vendor/github.com/pion/transport/v3/packetio/hardlimit.go @@ -0,0 +1,9 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +//go:build packetioSizeHardlimit +// +build packetioSizeHardlimit + +package packetio + +const sizeHardLimit = true diff --git a/vendor/github.com/pion/transport/v3/packetio/no_hardlimit.go b/vendor/github.com/pion/transport/v3/packetio/no_hardlimit.go new file mode 100644 index 0000000000..a59e259577 --- /dev/null +++ b/vendor/github.com/pion/transport/v3/packetio/no_hardlimit.go @@ -0,0 +1,9 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +//go:build !packetioSizeHardlimit +// +build !packetioSizeHardlimit + +package packetio + +const sizeHardLimit = false diff --git a/vendor/github.com/pion/transport/v3/replaydetector/fixedbig.go b/vendor/github.com/pion/transport/v3/replaydetector/fixedbig.go new file mode 100644 index 0000000000..80cb6b3057 --- /dev/null +++ b/vendor/github.com/pion/transport/v3/replaydetector/fixedbig.go @@ -0,0 +1,81 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package replaydetector + +import ( + "fmt" +) + +// fixedBigInt is the fix-sized multi-word integer. +type fixedBigInt struct { + bits []uint64 + n uint + msbMask uint64 +} + +// newFixedBigInt creates a new fix-sized multi-word int. +func newFixedBigInt(n uint) *fixedBigInt { + chunkSize := (n + 63) / 64 + if chunkSize == 0 { + chunkSize = 1 + } + return &fixedBigInt{ + bits: make([]uint64, chunkSize), + n: n, + msbMask: (1 << (64 - n%64)) - 1, + } +} + +// Lsh is the left shift operation. +func (s *fixedBigInt) Lsh(n uint) { + if n == 0 { + return + } + nChunk := int(n / 64) + nN := n % 64 + + for i := len(s.bits) - 1; i >= 0; i-- { + var carry uint64 + if i-nChunk >= 0 { + carry = s.bits[i-nChunk] << nN + if i-nChunk-1 >= 0 { + carry |= s.bits[i-nChunk-1] >> (64 - nN) + } + } + s.bits[i] = (s.bits[i] << n) | carry + } + s.bits[len(s.bits)-1] &= s.msbMask +} + +// Bit returns i-th bit of the fixedBigInt. +func (s *fixedBigInt) Bit(i uint) uint { + if i >= s.n { + return 0 + } + chunk := i / 64 + pos := i % 64 + if s.bits[chunk]&(1<= s.n { + return + } + chunk := i / 64 + pos := i % 64 + s.bits[chunk] |= 1 << pos +} + +// String returns string representation of fixedBigInt. 
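
This fixed-width integer serves as the sliding-window bitmask for the replay detectors defined in the next file; a usage sketch (hand-written, inside a hypothetical per-packet check; the window size and sequence bound are illustrative):

det := replaydetector.New(64, (1<<48)-1) // 64-packet window, 48-bit sequence space

accept, ok := det.Check(seq)
if !ok {
	return false // duplicate, too old, or beyond maxSeq: drop the packet
}
// authenticate the record first, then commit:
accept() // marks seq as seen and slides the window forward
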
+func (s *fixedBigInt) String() string {
+	var out string
+	for i := len(s.bits) - 1; i >= 0; i-- {
+		out += fmt.Sprintf("%016X", s.bits[i])
+	}
+	return out
+}
diff --git a/vendor/github.com/pion/transport/v3/replaydetector/replaydetector.go b/vendor/github.com/pion/transport/v3/replaydetector/replaydetector.go
new file mode 100644
index 0000000000..d40799568c
--- /dev/null
+++ b/vendor/github.com/pion/transport/v3/replaydetector/replaydetector.go
@@ -0,0 +1,132 @@
+// SPDX-FileCopyrightText: 2023 The Pion community
+// SPDX-License-Identifier: MIT
+
+// Package replaydetector provides a packet replay detection algorithm.
+package replaydetector
+
+// ReplayDetector is the interface of a sequence replay detector.
+type ReplayDetector interface {
+	// Check returns true if the given sequence number is not replayed.
+	// Call accept() to mark the packet as received properly.
+	// The return value of accept() indicates whether the accepted packet
+	// has the latest observed sequence number.
+	Check(seq uint64) (accept func() bool, ok bool)
+}
+
+// nop is a no-op func that is returned in the case that Check() fails.
+func nop() bool {
+	return false
+}
+
+type slidingWindowDetector struct {
+	latestSeq  uint64
+	maxSeq     uint64
+	windowSize uint
+	mask       *fixedBigInt
+}
+
+// New creates ReplayDetector.
+// Created ReplayDetector doesn't allow wrapping.
+// It can handle monotonically increasing sequence numbers up to the
+// full 64-bit range. It is suitable for DTLS replay protection.
+func New(windowSize uint, maxSeq uint64) ReplayDetector {
+	return &slidingWindowDetector{
+		maxSeq:     maxSeq,
+		windowSize: windowSize,
+		mask:       newFixedBigInt(windowSize),
+	}
+}
+
+func (d *slidingWindowDetector) Check(seq uint64) (func() bool, bool) {
+	if seq > d.maxSeq {
+		// Exceeded upper limit.
+		return nop, false
+	}
+
+	if seq <= d.latestSeq {
+		if d.latestSeq >= uint64(d.windowSize)+seq {
+			return nop, false
+		}
+		if d.mask.Bit(uint(d.latestSeq-seq)) != 0 {
+			// The sequence number is duplicated.
+			return nop, false
+		}
+	}
+
+	return func() bool {
+		latest := seq == 0
+		if seq > d.latestSeq {
+			// Update the head of the window.
+			d.mask.Lsh(uint(seq - d.latestSeq))
+			d.latestSeq = seq
+			latest = true
+		}
+		diff := (d.latestSeq - seq) % d.maxSeq
+		d.mask.SetBit(uint(diff))
+		return latest
+	}, true
+}
+
+// WithWrap creates ReplayDetector allowing sequence wrapping.
+// This is suitable for short bit width counter like SRTP and SRTCP.
+func WithWrap(windowSize uint, maxSeq uint64) ReplayDetector {
+	return &wrappedSlidingWindowDetector{
+		maxSeq:     maxSeq,
+		windowSize: windowSize,
+		mask:       newFixedBigInt(windowSize),
+	}
+}
+
+type wrappedSlidingWindowDetector struct {
+	latestSeq  uint64
+	maxSeq     uint64
+	windowSize uint
+	mask       *fixedBigInt
+	init       bool
+}
+
+func (d *wrappedSlidingWindowDetector) Check(seq uint64) (func() bool, bool) {
+	if seq > d.maxSeq {
+		// Exceeded upper limit.
+		return nop, false
+	}
+	if !d.init {
+		if seq != 0 {
+			d.latestSeq = seq - 1
+		} else {
+			d.latestSeq = d.maxSeq
+		}
+		d.init = true
+	}
+
+	diff := int64(d.latestSeq) - int64(seq)
+	// Wrap the number.
+	if diff > int64(d.maxSeq)/2 {
+		diff -= int64(d.maxSeq + 1)
+	} else if diff <= -int64(d.maxSeq)/2 {
+		diff += int64(d.maxSeq + 1)
+	}
+
+	if diff >= int64(d.windowSize) {
+		// Too old.
+		return nop, false
+	}
+	if diff >= 0 {
+		if d.mask.Bit(uint(diff)) != 0 {
+			// The sequence number is duplicated.
+			return nop, false
+		}
+	}
+
+	return func() bool {
+		latest := false
+		if diff < 0 {
+			// Update the head of the window.
+			d.mask.Lsh(uint(-diff))
+			d.latestSeq = seq
+			latest = true
+		}
+		d.mask.SetBit(uint(d.latestSeq - seq))
+		return latest
+	}, true
+}
diff --git a/vendor/github.com/pion/transport/v3/udp/batchconn.go b/vendor/github.com/pion/transport/v3/udp/batchconn.go
new file mode 100644
index 0000000000..54bdab65fe
--- /dev/null
+++ b/vendor/github.com/pion/transport/v3/udp/batchconn.go
@@ -0,0 +1,170 @@
+// SPDX-FileCopyrightText: 2023 The Pion community
+// SPDX-License-Identifier: MIT
+
+package udp
+
+import (
+	"io"
+	"net"
+	"runtime"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"golang.org/x/net/ipv4"
+	"golang.org/x/net/ipv6"
+)
+
+// BatchWriter represents a conn that can write messages in batch
+type BatchWriter interface {
+	WriteBatch(ms []ipv4.Message, flags int) (int, error)
+}
+
+// BatchReader represents a conn that can read messages in batch
+type BatchReader interface {
+	ReadBatch(msg []ipv4.Message, flags int) (int, error)
+}
+
+// BatchPacketConn represents a conn that can read/write messages in batch
+type BatchPacketConn interface {
+	BatchWriter
+	BatchReader
+	io.Closer
+}
+
+// BatchConn uses ipv4/v6.NewPacketConn to wrap a net.PacketConn so that messages
+// are written/read in batch. Batching is only available on Linux; on other
+// platforms it falls back to single Write/Read calls, the same as a plain net.Conn.
+type BatchConn struct {
+	net.PacketConn
+
+	batchConn BatchPacketConn
+
+	batchWriteMutex    sync.Mutex
+	batchWriteMessages []ipv4.Message
+	batchWritePos      int
+	batchWriteLast     time.Time
+
+	batchWriteSize     int
+	batchWriteInterval time.Duration
+
+	closed int32
+}
+
+// NewBatchConn creates a *BatchConn from net.PacketConn with batch configs.
+func NewBatchConn(conn net.PacketConn, batchWriteSize int, batchWriteInterval time.Duration) *BatchConn {
+	bc := &BatchConn{
+		PacketConn:         conn,
+		batchWriteLast:     time.Now(),
+		batchWriteInterval: batchWriteInterval,
+		batchWriteSize:     batchWriteSize,
+		batchWriteMessages: make([]ipv4.Message, batchWriteSize),
+	}
+	for i := range bc.batchWriteMessages {
+		bc.batchWriteMessages[i].Buffers = [][]byte{make([]byte, sendMTU)}
+	}
+
+	// batch write only supports linux
+	if runtime.GOOS == "linux" {
+		if pc4 := ipv4.NewPacketConn(conn); pc4 != nil {
+			bc.batchConn = pc4
+		} else if pc6 := ipv6.NewPacketConn(conn); pc6 != nil {
+			bc.batchConn = pc6
+		}
+	}
+
+	if bc.batchConn != nil {
+		go func() {
+			writeTicker := time.NewTicker(batchWriteInterval / 2)
+			defer writeTicker.Stop()
+			for atomic.LoadInt32(&bc.closed) != 1 {
+				<-writeTicker.C
+				bc.batchWriteMutex.Lock()
+				if bc.batchWritePos > 0 && time.Since(bc.batchWriteLast) >= bc.batchWriteInterval {
+					_ = bc.flush()
+				}
+				bc.batchWriteMutex.Unlock()
+			}
+		}()
+	}
+
+	return bc
+}
+
+// Close batchConn and the underlying PacketConn
+func (c *BatchConn) Close() error {
+	atomic.StoreInt32(&c.closed, 1)
+	c.batchWriteMutex.Lock()
+	if c.batchWritePos > 0 {
+		_ = c.flush()
+	}
+	c.batchWriteMutex.Unlock()
+	if c.batchConn != nil {
+		return c.batchConn.Close()
+	}
+	return c.PacketConn.Close()
+}
+
+// WriteTo writes a message to a UDPAddr; addr should be nil if it is a connected socket.
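
A construction sketch (hand-written, not part of the patch; the port and batch sizes are illustrative) for the batched socket described above:

pc, err := net.ListenUDP("udp", &net.UDPAddr{Port: 5684})
if err != nil {
	return err
}
// Queue up to 32 outgoing packets, flushing at least every 3ms; on
// non-Linux platforms this transparently degrades to plain WriteTo calls.
bc := udp.NewBatchConn(pc, 32, 3*time.Millisecond)
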
+func (c *BatchConn) WriteTo(b []byte, addr net.Addr) (int, error) {
+	if c.batchConn == nil {
+		return c.PacketConn.WriteTo(b, addr)
+	}
+	return c.enqueueMessage(b, addr)
+}
+
+func (c *BatchConn) enqueueMessage(buf []byte, raddr net.Addr) (int, error) {
+	var err error
+	c.batchWriteMutex.Lock()
+	defer c.batchWriteMutex.Unlock()
+
+	msg := &c.batchWriteMessages[c.batchWritePos]
+	// reset buffers
+	msg.Buffers = msg.Buffers[:1]
+	msg.Buffers[0] = msg.Buffers[0][:cap(msg.Buffers[0])]
+
+	c.batchWritePos++
+	if raddr != nil {
+		msg.Addr = raddr
+	}
+	if n := copy(msg.Buffers[0], buf); n < len(buf) {
+		extraBuffer := make([]byte, len(buf)-n)
+		copy(extraBuffer, buf[n:])
+		msg.Buffers = append(msg.Buffers, extraBuffer)
+	} else {
+		msg.Buffers[0] = msg.Buffers[0][:n]
+	}
+	if c.batchWritePos == c.batchWriteSize {
+		err = c.flush()
+	}
+	return len(buf), err
+}
+
+// ReadBatch reads messages in batch, returning the number of messages read and any error.
+func (c *BatchConn) ReadBatch(msgs []ipv4.Message, flags int) (int, error) {
+	if c.batchConn == nil {
+		n, addr, err := c.PacketConn.ReadFrom(msgs[0].Buffers[0])
+		if err == nil {
+			msgs[0].N = n
+			msgs[0].Addr = addr
+			return 1, nil
+		}
+		return 0, err
+	}
+	return c.batchConn.ReadBatch(msgs, flags)
+}
+
+func (c *BatchConn) flush() error {
+	var writeErr error
+	var txN int
+	for txN < c.batchWritePos {
+		n, err := c.batchConn.WriteBatch(c.batchWriteMessages[txN:c.batchWritePos], 0)
+		if err != nil {
+			writeErr = err
+			break
+		}
+		txN += n
+	}
+	c.batchWritePos = 0
+	c.batchWriteLast = time.Now()
+	return writeErr
+}
diff --git a/vendor/github.com/pion/transport/v3/udp/conn.go b/vendor/github.com/pion/transport/v3/udp/conn.go
new file mode 100644
index 0000000000..071b30e8e6
--- /dev/null
+++ b/vendor/github.com/pion/transport/v3/udp/conn.go
@@ -0,0 +1,389 @@
+// SPDX-FileCopyrightText: 2023 The Pion community
+// SPDX-License-Identifier: MIT
+
+// Package udp provides a connection-oriented listener over a UDP PacketConn
+package udp
+
+import (
+	"context"
+	"errors"
+	"net"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"github.com/pion/transport/v3/deadline"
+	"github.com/pion/transport/v3/packetio"
+	"golang.org/x/net/ipv4"
+)
+
+const (
+	receiveMTU           = 8192
+	sendMTU              = 1500
+	defaultListenBacklog = 128 // same as Linux default
+)
+
+// Typed errors
+var (
+	ErrClosedListener      = errors.New("udp: listener closed")
+	ErrListenQueueExceeded = errors.New("udp: listen queue exceeded")
+	ErrInvalidBatchConfig  = errors.New("udp: invalid batch config")
+)
+
+// listener augments a connection-oriented Listener over a UDP PacketConn
+type listener struct {
+	pConn net.PacketConn
+
+	readBatchSize int
+
+	accepting    atomic.Value // bool
+	acceptCh     chan *Conn
+	doneCh       chan struct{}
+	doneOnce     sync.Once
+	acceptFilter func([]byte) bool
+
+	connLock sync.Mutex
+	conns    map[string]*Conn
+	connWG   *sync.WaitGroup
+
+	readWG   sync.WaitGroup
+	errClose atomic.Value // error
+
+	readDoneCh chan struct{}
+	errRead    atomic.Value // error
+}
+
+// Accept waits for and returns the next connection to the listener.
+func (l *listener) Accept() (net.Conn, error) {
+	select {
+	case c := <-l.acceptCh:
+		l.connWG.Add(1)
+		return c, nil
+
+	case <-l.readDoneCh:
+		err, _ := l.errRead.Load().(error)
+		return nil, err
+
+	case <-l.doneCh:
+		return nil, ErrClosedListener
+	}
+}
+
+// Close closes the listener.
+// Any blocked Accept operations will be unblocked and return errors.
+func (l *listener) Close() error {
+	var err error
+	l.doneOnce.Do(func() {
+		l.accepting.Store(false)
+		close(l.doneCh)
+
+		l.connLock.Lock()
+		// Close unaccepted connections
+	lclose:
+		for {
+			select {
+			case c := <-l.acceptCh:
+				close(c.doneCh)
+				delete(l.conns, c.rAddr.String())
+
+			default:
+				break lclose
+			}
+		}
+		nConns := len(l.conns)
+		l.connLock.Unlock()
+
+		l.connWG.Done()
+
+		if nConns == 0 {
+			// Wait if this is the final connection
+			l.readWG.Wait()
+			if errClose, ok := l.errClose.Load().(error); ok {
+				err = errClose
+			}
+		} else {
+			err = nil
+		}
+	})
+
+	return err
+}
+
+// Addr returns the listener's network address.
+func (l *listener) Addr() net.Addr {
+	return l.pConn.LocalAddr()
+}
+
+// BatchIOConfig holds the config used to batch packet reads/writes with
+// ReadBatch/WriteBatch, improving throughput for UDP.
+type BatchIOConfig struct {
+	Enable bool
+	// ReadBatchSize indicates the maximum number of packets to be read in one batch;
+	// a batch size less than 2 disables batched reads.
+	ReadBatchSize int
+	// WriteBatchSize indicates the maximum number of packets to be written in one batch
+	WriteBatchSize int
+	// WriteBatchInterval indicates the maximum interval to wait before writing packets in one batch;
+	// a small interval reduces latency/jitter but increases the number of I/O calls.
+	WriteBatchInterval time.Duration
+}
+
+// ListenConfig stores options for listening to an address.
+type ListenConfig struct {
+	// Backlog defines the maximum length of the queue of pending
+	// connections. It is the equivalent of the backlog argument of
+	// the POSIX listen function.
+	// If a connection request arrives when the queue is full,
+	// the request will be silently discarded, unlike TCP.
+	// Set zero to use the default value 128, which is the same as the Linux default.
+	Backlog int
+
+	// AcceptFilter determines whether the new conn should be made for
+	// the incoming packet. If not set, any packet creates new conn.
+	AcceptFilter func([]byte) bool
+
+	// ReadBufferSize sets the size of the operating system's
+	// receive buffer associated with the listener.
+	ReadBufferSize int
+
+	// WriteBufferSize sets the size of the operating system's
+	// send buffer associated with the connection.
+	WriteBufferSize int
+
+	Batch BatchIOConfig
+}
+
+// Listen creates a new listener based on the ListenConfig.
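
A usage sketch (hand-written, not part of the patch; the filter, port, and handle function are assumptions) for the connection-oriented UDP listener configured here:

lc := udp.ListenConfig{
	Backlog: 64,
	// Create a Conn only for packets that look like a DTLS handshake
	// (content type 22); any predicate over the first packet works.
	AcceptFilter: func(pkt []byte) bool {
		return len(pkt) > 0 && pkt[0] == 22
	},
}
l, err := lc.Listen("udp", &net.UDPAddr{Port: 5684})
if err != nil {
	return err
}
for {
	conn, err := l.Accept() // one net.Conn per remote address
	if err != nil {
		return err
	}
	go handle(conn) // hypothetical handler
}
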
+func (lc *ListenConfig) Listen(network string, laddr *net.UDPAddr) (net.Listener, error) {
+	if lc.Backlog == 0 {
+		lc.Backlog = defaultListenBacklog
+	}
+
+	if lc.Batch.Enable && (lc.Batch.WriteBatchSize <= 0 || lc.Batch.WriteBatchInterval <= 0) {
+		return nil, ErrInvalidBatchConfig
+	}
+
+	conn, err := net.ListenUDP(network, laddr)
+	if err != nil {
+		return nil, err
+	}
+
+	if lc.ReadBufferSize > 0 {
+		_ = conn.SetReadBuffer(lc.ReadBufferSize)
+	}
+	if lc.WriteBufferSize > 0 {
+		_ = conn.SetWriteBuffer(lc.WriteBufferSize)
+	}
+
+	l := &listener{
+		pConn:        conn,
+		acceptCh:     make(chan *Conn, lc.Backlog),
+		conns:        make(map[string]*Conn),
+		doneCh:       make(chan struct{}),
+		acceptFilter: lc.AcceptFilter,
+		connWG:       &sync.WaitGroup{},
+		readDoneCh:   make(chan struct{}),
+	}
+
+	if lc.Batch.Enable {
+		l.pConn = NewBatchConn(conn, lc.Batch.WriteBatchSize, lc.Batch.WriteBatchInterval)
+		l.readBatchSize = lc.Batch.ReadBatchSize
+	}
+
+	l.accepting.Store(true)
+	l.connWG.Add(1)
+	l.readWG.Add(2) // wait readLoop and Close execution routine
+
+	go l.readLoop()
+	go func() {
+		l.connWG.Wait()
+		if err := l.pConn.Close(); err != nil {
+			l.errClose.Store(err)
+		}
+		l.readWG.Done()
+	}()
+
+	return l, nil
+}
+
+// Listen creates a new listener using default ListenConfig.
+func Listen(network string, laddr *net.UDPAddr) (net.Listener, error) {
+	return (&ListenConfig{}).Listen(network, laddr)
+}
+
+// readLoop has two tasks:
+// 1. Dispatching incoming packets to the correct Conn.
+//    It can therefore not be ended until all Conns are closed.
+// 2. Creating a new Conn when receiving from a new remote.
+func (l *listener) readLoop() {
+	defer l.readWG.Done()
+	defer close(l.readDoneCh)
+
+	if br, ok := l.pConn.(BatchReader); ok && l.readBatchSize > 1 {
+		l.readBatch(br)
+	} else {
+		l.read()
+	}
+}
+
+func (l *listener) readBatch(br BatchReader) {
+	msgs := make([]ipv4.Message, l.readBatchSize)
+	for i := range msgs {
+		msg := &msgs[i]
+		msg.Buffers = [][]byte{make([]byte, receiveMTU)}
+		msg.OOB = make([]byte, 40)
+	}
+	for {
+		n, err := br.ReadBatch(msgs, 0)
+		if err != nil {
+			l.errRead.Store(err)
+			return
+		}
+		for i := 0; i < n; i++ {
+			l.dispatchMsg(msgs[i].Addr, msgs[i].Buffers[0][:msgs[i].N])
+		}
+	}
+}
+
+func (l *listener) read() {
+	buf := make([]byte, receiveMTU)
+	for {
+		n, raddr, err := l.pConn.ReadFrom(buf)
+		if err != nil {
+			l.errRead.Store(err)
+			return
+		}
+		l.dispatchMsg(raddr, buf[:n])
+	}
+}
+
+func (l *listener) dispatchMsg(addr net.Addr, buf []byte) {
+	conn, ok, err := l.getConn(addr, buf)
+	if err != nil {
+		return
+	}
+	if ok {
+		_, _ = conn.buffer.Write(buf)
+	}
+}
+
+func (l *listener) getConn(raddr net.Addr, buf []byte) (*Conn, bool, error) {
+	l.connLock.Lock()
+	defer l.connLock.Unlock()
+	conn, ok := l.conns[raddr.String()]
+	if !ok {
+		if isAccepting, ok := l.accepting.Load().(bool); !isAccepting || !ok {
+			return nil, false, ErrClosedListener
+		}
+		if l.acceptFilter != nil {
+			if !l.acceptFilter(buf) {
+				return nil, false, nil
+			}
+		}
+		conn = l.newConn(raddr)
+		select {
+		case l.acceptCh <- conn:
+			l.conns[raddr.String()] = conn
+		default:
+			return nil, false, ErrListenQueueExceeded
+		}
+	}
+	return conn, true, nil
+}
+
+// Conn augments a connection-oriented connection over a UDP PacketConn
+type Conn struct {
+	listener *listener
+
+	rAddr net.Addr
+
+	buffer *packetio.Buffer
+
+	doneCh   chan struct{}
+	doneOnce sync.Once
+
+	writeDeadline *deadline.Deadline
+}
+
+func (l *listener) newConn(rAddr net.Addr) *Conn {
+	return &Conn{
+		listener:      l,
+		rAddr:         rAddr,
+		buffer:        packetio.NewBuffer(),
packetio.NewBuffer(), + doneCh: make(chan struct{}), + writeDeadline: deadline.New(), + } +} + +// Read reads from c into p +func (c *Conn) Read(p []byte) (int, error) { + return c.buffer.Read(p) +} + +// Write writes len(p) bytes from p to the DTLS connection +func (c *Conn) Write(p []byte) (n int, err error) { + select { + case <-c.writeDeadline.Done(): + return 0, context.DeadlineExceeded + default: + } + return c.listener.pConn.WriteTo(p, c.rAddr) +} + +// Close closes the conn and releases any Read calls +func (c *Conn) Close() error { + var err error + c.doneOnce.Do(func() { + c.listener.connWG.Done() + close(c.doneCh) + c.listener.connLock.Lock() + delete(c.listener.conns, c.rAddr.String()) + nConns := len(c.listener.conns) + c.listener.connLock.Unlock() + + if isAccepting, ok := c.listener.accepting.Load().(bool); nConns == 0 && !isAccepting && ok { + // Wait if this is the final connection + c.listener.readWG.Wait() + if errClose, ok := c.listener.errClose.Load().(error); ok { + err = errClose + } + } else { + err = nil + } + + if errBuf := c.buffer.Close(); errBuf != nil && err == nil { + err = errBuf + } + }) + + return err +} + +// LocalAddr implements net.Conn.LocalAddr +func (c *Conn) LocalAddr() net.Addr { + return c.listener.pConn.LocalAddr() +} + +// RemoteAddr implements net.Conn.RemoteAddr +func (c *Conn) RemoteAddr() net.Addr { + return c.rAddr +} + +// SetDeadline implements net.Conn.SetDeadline +func (c *Conn) SetDeadline(t time.Time) error { + c.writeDeadline.Set(t) + return c.SetReadDeadline(t) +} + +// SetReadDeadline implements net.Conn.SetDeadline +func (c *Conn) SetReadDeadline(t time.Time) error { + return c.buffer.SetReadDeadline(t) +} + +// SetWriteDeadline implements net.Conn.SetDeadline +func (c *Conn) SetWriteDeadline(t time.Time) error { + c.writeDeadline.Set(t) + // Write deadline of underlying connection should not be changed + // since the connection can be shared. 
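+	// The deadline set above is enforced locally by Write, which checks
+	// writeDeadline.Done() before each send.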
+ return nil +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/.codecov.yml b/vendor/github.com/plgd-dev/go-coap/v3/.codecov.yml new file mode 100644 index 0000000000..3f1b1a655a --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/.codecov.yml @@ -0,0 +1,4 @@ +ignore: + - "examples" + - "**/*_test.go" + - "v3" diff --git a/vendor/github.com/plgd-dev/go-coap/v3/.gitignore b/vendor/github.com/plgd-dev/go-coap/v3/.gitignore new file mode 100644 index 0000000000..8d21c950b9 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/.gitignore @@ -0,0 +1,19 @@ +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib +debug +server +!server/ +client +!client/ +vendor/ +v3/ + +# Test binary, build with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out diff --git a/vendor/github.com/plgd-dev/go-coap/v3/.golangci.yml b/vendor/github.com/plgd-dev/go-coap/v3/.golangci.yml new file mode 100644 index 0000000000..2d5d6ae8fa --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/.golangci.yml @@ -0,0 +1,125 @@ +linters-settings: + govet: + check-shadowing: true + gocyclo: + min-complexity: 15 + +linters: + enable: + - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers + - bidichk # Checks for dangerous unicode character sequences + # - bodyclose # checks whether HTTP response body is closed successfully + # - contextcheck # check the function whether use a non-inherited context + - decorder # check declaration order and count of types, constants, variables and functions + # - depguard # Go linter that checks if package imports are in a list of acceptable packages + - dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) + - dupl # Tool for code clone detection + - durationcheck # check for two durations multiplied together + - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases + - errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted. + - errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`. + - errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. + # - exhaustive # check exhaustiveness of enum switch statements + - exportloopref # checks for pointers to enclosing loop variables + # - forbidigo # Forbids identifiers + - forcetypeassert # finds forced type assertions + - gci # Gci control golang package import order and make it always deterministic. + # - gochecknoglobals # Checks that no globals are present in Go code + # - gochecknoinits # Checks that no init functions are present in Go code + - gocognit # Computes and checks the cognitive complexity of functions + - goconst # Finds repeated strings that could be replaced by a constant + - gocritic # The most opinionated Go source code linter + - gocyclo # Computes and checks the cyclomatic complexity of functions + # - godox # Tool for detection of FIXME, TODO and other comment keywords + # - goerr113 # Golang linter to check the errors handling expressions + - gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification + - gofumpt # Gofumpt checks whether code was gofumpt-ed. 
+ - goheader # Checks is file header matches to pattern + - goimports # Goimports does everything that gofmt does. Additionally it checks unused imports + - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. + - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. + - goprintffuncname # Checks that printf-like functions are named with `f` at the end + - gosec # Inspects source code for security problems + - gosimple # Linter for Go source code that specializes in simplifying a code + - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string + - grouper # An analyzer to analyze expression groups. + - importas # Enforces consistent import aliases + - ineffassign # Detects when assignments to existing variables are not used + - misspell # Finds commonly misspelled English words in comments + # - nakedret # Finds naked returns in functions greater than a specified function length + - nilerr # Finds the code that returns nil even if it checks that the error is not nil. + - nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value. + # - noctx # noctx finds sending http request without context.Context + - nolintlint # Reports ill-formed or insufficient nolint directives + # - paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test + - predeclared # find code that shadows one of Go's predeclared identifiers + - revive # golint replacement, finds style mistakes + - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks + - stylecheck # Stylecheck is a replacement for golint + # - tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17 + # - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes + - typecheck # Like the front-end of a Go compiler, parses and type-checks Go code + - unconvert # Remove unnecessary type conversions + - unparam # Reports unused function parameters + - unused # Checks Go code for unused constants, variables, functions and types + # - wastedassign # wastedassign finds wasted assignment statements + - whitespace # Tool for detection of leading and trailing whitespace + disable: + - containedctx # containedctx is a linter that detects struct contained context.Context field + - cyclop # checks function and package cyclomatic complexity + - exhaustivestruct # Checks if all struct's fields are initialized + - funlen # Tool for detection of long functions + - godot # Check if comments end in a period + - gomnd # An analyzer to detect magic numbers. + - ifshort # Checks that your code uses short syntax for if-statements whenever possible + - ireturn # Accept Interfaces, Return Concrete Types + - lll # Reports long lines + - maintidx # maintidx measures the maintainability index of each function. 
+ - makezero # Finds slice declarations with non-zero initial length + - maligned # Tool to detect Go structs that would take less memory if their fields were sorted + - nestif # Reports deeply nested if statements + - nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity + - prealloc # Finds slice declarations that could potentially be preallocated + - promlinter # Check Prometheus metrics naming via promlint + - rowserrcheck # checks whether Err of rows is checked successfully + - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. + - structcheck # Finds unused struct fields + - tagliatelle # Checks the struct tags. + - testpackage # linter that makes you use a separate _test package + - thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers + - varnamelen # checks that the length of a variable's name matches its scope + - wrapcheck # Checks that errors returned from external packages are wrapped + - wsl # Whitespace Linter - Forces you to use empty lines! + +issues: + # Excluding configuration per-path, per-linter, per-text and per-source + exclude-rules: + # Exclude some linters from running on tests files. + - path: _test\.go + linters: + - dupl + - forcetypeassert + - gosec + - gocyclo + - gocognit + + - path: ^test/.*\.go + linters: + - dupl + - forcetypeassert + - gosec + - gocyclo + - gocognit + + - path: example_test\.go + text: "exitAfterDefer" + linters: + - gocritic + + - path: pkg/rand/rand.go + text: "G404: Use of weak random number generator \\(math/rand instead of crypto/rand\\)" + linters: + - gosec + +# # Fix found issues (if it's supported by the linter). +# fix: true diff --git a/vendor/github.com/plgd-dev/go-coap/v3/LICENSE b/vendor/github.com/plgd-dev/go-coap/v3/LICENSE new file mode 100644 index 0000000000..261eeb9e9f --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. 
+ + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/vendor/github.com/plgd-dev/go-coap/v3/README.md b/vendor/github.com/plgd-dev/go-coap/v3/README.md
new file mode 100644
index 0000000000..65e6d794e4
--- /dev/null
+++ b/vendor/github.com/plgd-dev/go-coap/v3/README.md
@@ -0,0 +1,153 @@
+# Go-CoAP
+
+[![Build Status](https://github.com/plgd-dev/go-coap/workflows/Test/badge.svg)](https://github.com/plgd-dev/go-coap/actions?query=workflow%3ATest)
+[![Quality Gate Status](https://sonarcloud.io/api/project_badges/measure?project=plgd-dev_go-coap&metric=alert_status)](https://sonarcloud.io/summary/new_code?id=plgd-dev_go-coap)
+[![Coverage](https://img.shields.io/sonar/coverage/plgd-dev_go-coap?server=https%3A%2F%2Fsonarcloud.io)](https://sonarcloud.io/summary/overall?id=plgd-dev_go-coap)
+[![FOSSA Status](https://app.fossa.io/api/projects/git%2Bgithub.com%2Fplgd-dev%2Fgo-coap.svg?type=shield)](https://app.fossa.io/projects/git%2Bgithub.com%2Fplgd-dev%2Fgo-coap?ref=badge_shield)
+[![sponsors](https://opencollective.com/go-coap/sponsors/badge.svg)](https://opencollective.com/go-coap#sponsors)
+[![contributors](https://img.shields.io/github/contributors/plgd-dev/go-coap)](https://github.com/plgd-dev/go-coap/graphs/contributors)
+[![GitHub stars](https://img.shields.io/github/stars/plgd-dev/go-coap)](https://github.com/plgd-dev/go-coap/stargazers)
+[![GitHub license](https://img.shields.io/github/license/plgd-dev/go-coap)](https://github.com/plgd-dev/go-coap/blob/master/LICENSE)
+[![GoDoc](https://pkg.go.dev/badge/github.com/plgd-dev/go-coap/v3?utm_source=godoc)](https://pkg.go.dev/github.com/plgd-dev/go-coap/v3?utm_source=godoc)
+
+The Constrained Application Protocol (CoAP) is a specialized web transfer protocol for use with constrained nodes and constrained networks in the Internet of Things.
+The protocol is designed for machine-to-machine (M2M) applications such as smart energy and building automation.
+
+go-coap provides servers and clients for DTLS, TCP-TLS, UDP, and TCP in Go.
+
+## Features
+
+* CoAP over UDP [RFC 7252][coap]
+* CoAP over TCP/TLS [RFC 8323][coap-tcp]
+* Observe resources in CoAP [RFC 7641][coap-observe]
+* Block-wise transfers in CoAP [RFC 7959][coap-block-wise-transfers]
+* Request multiplexer
+* Multicast
+* CoAP NoResponse option [RFC 7967][coap-noresponse]
+* CoAP over DTLS [pion/dtls][pion-dtls]
+* Too Many Requests response code [RFC 8516][coap-429]
+
+[coap]: http://tools.ietf.org/html/rfc7252
+[coap-tcp]: https://tools.ietf.org/html/rfc8323
+[coap-block-wise-transfers]: https://tools.ietf.org/html/rfc7959
+[coap-observe]: https://tools.ietf.org/html/rfc7641
+[coap-noresponse]: https://tools.ietf.org/html/rfc7967
+[pion-dtls]: https://github.com/pion/dtls
+[coap-429]: https://datatracker.ietf.org/doc/html/rfc8516
+
+## Requirements
+
+* Go 1.18 or higher
+
+## Samples
+
+### Simple
+
+#### Server UDP/TCP
+
+```go
+// Server
+
+// Middleware function, which will be called for each request.
+func loggingMiddleware(next mux.Handler) mux.Handler {
+	return mux.HandlerFunc(func(w mux.ResponseWriter, r *mux.Message) {
+		log.Printf("ClientAddress %v, %v\n", w.Conn().RemoteAddr(), r.String())
+		next.ServeCOAP(w, r)
+	})
+}
+
+// See /examples/simple/server/main.go
+func handleA(w mux.ResponseWriter, req *mux.Message) {
+	err := w.SetResponse(codes.Content, message.TextPlain, bytes.NewReader([]byte("hello world")))
+	if err != nil {
+		log.Printf("cannot set response: %v", err)
+	}
+}
+
+func main() {
+	r := mux.NewRouter()
+	r.Use(loggingMiddleware)
+	r.Handle("/a", mux.HandlerFunc(handleA))
+	r.Handle("/b", mux.HandlerFunc(handleB))
+
+	log.Fatal(coap.ListenAndServe("udp", ":5688", r))
+
+	// for tcp
+	// log.Fatal(coap.ListenAndServe("tcp", ":5688", r))
+
+	// for tcp-tls
+	// log.Fatal(coap.ListenAndServeTLS("tcp", ":5688", &tls.Config{...}, r))
+
+	// for udp-dtls
+	// log.Fatal(coap.ListenAndServeDTLS("udp", ":5688", &dtls.Config{...}, r))
+}
+```
+
+#### Client
+
+```go
+// Client
+// See /examples/simpler/client/main.go
+func main() {
+	co, err := udp.Dial("localhost:5688")
+
+	// for tcp
+	// co, err := tcp.Dial("localhost:5688")
+
+	// for tcp-tls
+	// co, err := tcp.Dial("localhost:5688", tcp.WithTLS(&tls.Config{...}))
+
+	// for dtls
+	// co, err := dtls.Dial("localhost:5688", &dtls.Config{...})
+
+	if err != nil {
+		log.Fatalf("Error dialing: %v", err)
+	}
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+	defer cancel()
+	resp, err := co.Get(ctx, "/a")
+	if err != nil {
+		log.Fatalf("Cannot get response: %v", err)
+	}
+	log.Printf("Response: %+v", resp)
+}
+```
+
+### Observe / Notify
+
+[Server](examples/observe/server/main.go) example.
+
+[Client](examples/observe/client/main.go) example.
+
+### Multicast
+
+[Server](examples/mcast/server/main.go) example.
+
+[Client](examples/mcast/client/main.go) example.
+
+## License
+
+Apache 2.0
+
+[![FOSSA Status](https://app.fossa.io/api/projects/git%2Bgithub.com%2Fplgd-dev%2Fgo-coap.svg?type=large)](https://app.fossa.io/projects/git%2Bgithub.com%2Fplgd-dev%2Fgo-coap?ref=badge_large)
+

+## Sponsors
+
+[Become a sponsor](https://opencollective.com/go-coap#sponsor) and get your logo on our README on Github with a link to your site.
+
+## Backers
+ +[Become a backer](https://opencollective.com/go-coap#backer) and get your image on our README on Github with a link to your site. + + diff --git a/vendor/github.com/plgd-dev/go-coap/v3/dtls/client.go b/vendor/github.com/plgd-dev/go-coap/v3/dtls/client.go new file mode 100644 index 0000000000..e2075f364a --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/dtls/client.go @@ -0,0 +1,127 @@ +package dtls + +import ( + "fmt" + "time" + + "github.com/pion/dtls/v2" + dtlsnet "github.com/pion/dtls/v2/pkg/net" + "github.com/plgd-dev/go-coap/v3/dtls/server" + "github.com/plgd-dev/go-coap/v3/message" + "github.com/plgd-dev/go-coap/v3/message/codes" + "github.com/plgd-dev/go-coap/v3/message/pool" + coapNet "github.com/plgd-dev/go-coap/v3/net" + "github.com/plgd-dev/go-coap/v3/net/blockwise" + "github.com/plgd-dev/go-coap/v3/net/monitor/inactivity" + "github.com/plgd-dev/go-coap/v3/net/responsewriter" + "github.com/plgd-dev/go-coap/v3/options" + "github.com/plgd-dev/go-coap/v3/udp" + udpClient "github.com/plgd-dev/go-coap/v3/udp/client" +) + +var DefaultConfig = func() udpClient.Config { + cfg := udpClient.DefaultConfig + cfg.Handler = func(w *responsewriter.ResponseWriter[*udpClient.Conn], r *pool.Message) { + switch r.Code() { + case codes.POST, codes.PUT, codes.GET, codes.DELETE: + if err := w.SetResponse(codes.NotFound, message.TextPlain, nil); err != nil { + cfg.Errors(fmt.Errorf("dtls client: cannot set response: %w", err)) + } + } + } + return cfg +}() + +// Dial creates a client connection to the given target. +func Dial(target string, dtlsCfg *dtls.Config, opts ...udp.Option) (*udpClient.Conn, error) { + cfg := DefaultConfig + for _, o := range opts { + o.UDPClientApply(&cfg) + } + + c, err := cfg.Dialer.DialContext(cfg.Ctx, cfg.Net, target) + if err != nil { + return nil, err + } + + conn, err := dtls.Client(dtlsnet.PacketConnFromConn(c), c.RemoteAddr(), dtlsCfg) + if err != nil { + return nil, err + } + opts = append(opts, options.WithCloseSocket()) + return Client(conn, opts...), nil +} + +// Client creates client over dtls connection. +func Client(conn *dtls.Conn, opts ...udp.Option) *udpClient.Conn { + cfg := DefaultConfig + for _, o := range opts { + o.UDPClientApply(&cfg) + } + if cfg.Errors == nil { + cfg.Errors = func(error) { + // default no-op + } + } + if cfg.CreateInactivityMonitor == nil { + cfg.CreateInactivityMonitor = func() udpClient.InactivityMonitor { + return inactivity.NewNilMonitor[*udpClient.Conn]() + } + } + if cfg.MessagePool == nil { + cfg.MessagePool = pool.New(0, 0) + } + errorsFunc := cfg.Errors + cfg.Errors = func(err error) { + if coapNet.IsCancelOrCloseError(err) { + // this error was produced by cancellation context or closing connection. 
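+			// Drop it here rather than forwarding it to the user-supplied error handler.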
+ return + } + errorsFunc(fmt.Errorf("dtls: %v: %w", conn.RemoteAddr(), err)) + } + + createBlockWise := func(cc *udpClient.Conn) *blockwise.BlockWise[*udpClient.Conn] { + return nil + } + if cfg.BlockwiseEnable { + createBlockWise = func(cc *udpClient.Conn) *blockwise.BlockWise[*udpClient.Conn] { + v := cc + return blockwise.New( + v, + cfg.BlockwiseTransferTimeout, + cfg.Errors, + func(token message.Token) (*pool.Message, bool) { + return v.GetObservationRequest(token) + }, + ) + } + } + + monitor := cfg.CreateInactivityMonitor() + l := coapNet.NewConn(conn) + session := server.NewSession(cfg.Ctx, + l, + cfg.MaxMessageSize, + cfg.MTU, + cfg.CloseSocket, + ) + cc := udpClient.NewConn(session, + createBlockWise, + monitor, + &cfg, + ) + + cfg.PeriodicRunner(func(now time.Time) bool { + cc.CheckExpirations(now) + return cc.Context().Err() == nil + }) + + go func() { + err := cc.Run() + if err != nil { + cfg.Errors(err) + } + }() + + return cc +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/dtls/server.go b/vendor/github.com/plgd-dev/go-coap/v3/dtls/server.go new file mode 100644 index 0000000000..881c844dad --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/dtls/server.go @@ -0,0 +1,7 @@ +package dtls + +import "github.com/plgd-dev/go-coap/v3/dtls/server" + +func NewServer(opt ...server.Option) *server.Server { + return server.New(opt...) +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/dtls/server/config.go b/vendor/github.com/plgd-dev/go-coap/v3/dtls/server/config.go new file mode 100644 index 0000000000..dba7297a7d --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/dtls/server/config.go @@ -0,0 +1,64 @@ +package server + +import ( + "fmt" + "time" + + "github.com/plgd-dev/go-coap/v3/message" + "github.com/plgd-dev/go-coap/v3/message/codes" + "github.com/plgd-dev/go-coap/v3/message/pool" + "github.com/plgd-dev/go-coap/v3/net/monitor/inactivity" + "github.com/plgd-dev/go-coap/v3/net/responsewriter" + "github.com/plgd-dev/go-coap/v3/options/config" + udpClient "github.com/plgd-dev/go-coap/v3/udp/client" +) + +// The HandlerFunc type is an adapter to allow the use of +// ordinary functions as COAP handlers. +type HandlerFunc = func(*responsewriter.ResponseWriter[*udpClient.Conn], *pool.Message) + +type ErrorFunc = func(error) + +// OnNewConnFunc is the callback for new connections. 
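+// A hypothetical usage sketch (not part of the original file):
+//
+//	cfg.OnNewConn = func(cc *udpClient.Conn) {
+//		log.Printf("new DTLS session from %v", cc.RemoteAddr())
+//	}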
+type OnNewConnFunc = func(cc *udpClient.Conn)
+
+type GetMIDFunc = func() int32
+
+var DefaultConfig = func() Config {
+	opts := Config{
+		Common: config.NewCommon[*udpClient.Conn](),
+		CreateInactivityMonitor: func() udpClient.InactivityMonitor {
+			timeout := time.Second * 16
+			onInactive := func(cc *udpClient.Conn) {
+				_ = cc.Close()
+			}
+			return inactivity.New(timeout, onInactive)
+		},
+		OnNewConn: func(cc *udpClient.Conn) {
+			// do nothing by default
+		},
+		TransmissionNStart:             1,
+		TransmissionAcknowledgeTimeout: time.Second * 2,
+		TransmissionMaxRetransmit:      4,
+		GetMID:                         message.GetMID,
+		MTU:                            udpClient.DefaultMTU,
+	}
+	opts.Handler = func(w *responsewriter.ResponseWriter[*udpClient.Conn], r *pool.Message) {
+		if err := w.SetResponse(codes.NotFound, message.TextPlain, nil); err != nil {
+			opts.Errors(fmt.Errorf("dtls server: cannot set response: %w", err))
+		}
+	}
+	return opts
+}()
+
+type Config struct {
+	config.Common[*udpClient.Conn]
+	CreateInactivityMonitor        func() udpClient.InactivityMonitor
+	GetMID                         GetMIDFunc
+	Handler                        HandlerFunc
+	OnNewConn                      OnNewConnFunc
+	TransmissionNStart             uint32
+	TransmissionAcknowledgeTimeout time.Duration
+	TransmissionMaxRetransmit      uint32
+	MTU                            uint16
+}
diff --git a/vendor/github.com/plgd-dev/go-coap/v3/dtls/server/server.go b/vendor/github.com/plgd-dev/go-coap/v3/dtls/server/server.go
new file mode 100644
index 0000000000..01587f854d
--- /dev/null
+++ b/vendor/github.com/plgd-dev/go-coap/v3/dtls/server/server.go
@@ -0,0 +1,231 @@
+package server
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"net"
+	"sync"
+	"time"
+
+	"github.com/plgd-dev/go-coap/v3/message"
+	"github.com/plgd-dev/go-coap/v3/message/pool"
+	coapNet "github.com/plgd-dev/go-coap/v3/net"
+	"github.com/plgd-dev/go-coap/v3/net/blockwise"
+	"github.com/plgd-dev/go-coap/v3/net/monitor/inactivity"
+	"github.com/plgd-dev/go-coap/v3/pkg/connections"
+	udpClient "github.com/plgd-dev/go-coap/v3/udp/client"
+)
+
+// Listener defines the listener used by coap.
+type Listener interface {
+	Close() error
+	AcceptWithContext(ctx context.Context) (net.Conn, error)
+}
+
+type Server struct {
+	ctx    context.Context
+	cancel context.CancelFunc
+	cfg    *Config
+
+	listenMutex sync.Mutex
+	listen      Listener
+}
+
+// An Option sets options such as credentials, codec and keepalive parameters, etc.
+type Option interface {
+	DTLSServerApply(cfg *Config)
+}
+
+func New(opt ...Option) *Server {
+	cfg := DefaultConfig
+	for _, o := range opt {
+		o.DTLSServerApply(&cfg)
+	}
+
+	ctx, cancel := context.WithCancel(cfg.Ctx)
+	if cfg.Errors == nil {
+		cfg.Errors = func(error) {
+			// default no-op
+		}
+	}
+
+	if cfg.GetMID == nil {
+		cfg.GetMID = message.GetMID
+	}
+
+	if cfg.GetToken == nil {
+		cfg.GetToken = message.GetToken
+	}
+
+	if cfg.CreateInactivityMonitor == nil {
+		cfg.CreateInactivityMonitor = func() udpClient.InactivityMonitor {
+			return inactivity.NewNilMonitor[*udpClient.Conn]()
+		}
+	}
+	if cfg.MessagePool == nil {
+		cfg.MessagePool = pool.New(0, 0)
+	}
+
+	errorsFunc := cfg.Errors
+	// assign the updated func to cfg.Errors so cfg.Handler also uses the updated error handler
+	cfg.Errors = func(err error) {
+		if coapNet.IsCancelOrCloseError(err) {
+			// this error was produced by a cancellation context or a closed connection.
+			return
+		}
+		errorsFunc(fmt.Errorf("dtls: %w", err))
+	}
+
+	return &Server{
+		ctx:    ctx,
+		cancel: cancel,
+		cfg:    &cfg,
+	}
+}
+
+func (s *Server) checkAndSetListener(l Listener) error {
+	s.listenMutex.Lock()
+	defer s.listenMutex.Unlock()
+	if s.listen != nil {
+		return fmt.Errorf("server already serve listener")
+	}
+	s.listen = l
+	return nil
+}
+
+func (s *Server) checkAcceptError(err error) bool {
+	if err == nil {
+		return true
+	}
+	switch {
+	case errors.Is(err, coapNet.ErrListenerIsClosed):
+		s.Stop()
+		return false
+	case errors.Is(err, context.DeadlineExceeded), errors.Is(err, context.Canceled):
+		select {
+		case <-s.ctx.Done():
+		default:
+			s.cfg.Errors(fmt.Errorf("cannot accept connection: %w", err))
+			return true
+		}
+		return false
+	default:
+		return true
+	}
+}
+
+func (s *Server) serveConnection(connections *connections.Connections, cc *udpClient.Conn) {
+	connections.Store(cc)
+	defer connections.Delete(cc)
+
+	if err := cc.Run(); err != nil {
+		s.cfg.Errors(fmt.Errorf("%v: %w", cc.RemoteAddr(), err))
+	}
+}
+
+func (s *Server) Serve(l Listener) error {
+	if s.cfg.BlockwiseSZX > blockwise.SZX1024 {
+		return fmt.Errorf("invalid blockwiseSZX")
+	}
+	err := s.checkAndSetListener(l)
+	if err != nil {
+		return err
+	}
+	defer func() {
+		s.listenMutex.Lock()
+		defer s.listenMutex.Unlock()
+		s.listen = nil
+	}()
+
+	var wg sync.WaitGroup
+	defer wg.Wait()
+
+	connections := connections.New()
+	s.cfg.PeriodicRunner(func(now time.Time) bool {
+		connections.CheckExpirations(now)
+		return s.ctx.Err() == nil
+	})
+	defer connections.Close()
+
+	for {
+		rw, err := l.AcceptWithContext(s.ctx)
+		if ok := s.checkAcceptError(err); !ok {
+			return nil
+		}
+		if rw == nil {
+			continue
+		}
+		wg.Add(1)
+		var cc *udpClient.Conn
+		monitor := s.cfg.CreateInactivityMonitor()
+		cc = s.createConn(coapNet.NewConn(rw), monitor)
+		if s.cfg.OnNewConn != nil {
+			s.cfg.OnNewConn(cc)
+		}
+		go func() {
+			defer wg.Done()
+			s.serveConnection(connections, cc)
+		}()
+	}
+}
+
+// Stop stops the server without waiting for the Serve function to finish.
+func (s *Server) Stop() {
+	s.cancel()
+	s.listenMutex.Lock()
+	l := s.listen
+	s.listen = nil
+	s.listenMutex.Unlock()
+	if l != nil {
+		if err := l.Close(); err != nil {
+			s.cfg.Errors(fmt.Errorf("cannot close listener: %w", err))
+		}
+	}
+}
+
+func (s *Server) createConn(connection *coapNet.Conn, monitor udpClient.InactivityMonitor) *udpClient.Conn {
+	createBlockWise := func(cc *udpClient.Conn) *blockwise.BlockWise[*udpClient.Conn] {
+		return nil
+	}
+	if s.cfg.BlockwiseEnable {
+		createBlockWise = func(cc *udpClient.Conn) *blockwise.BlockWise[*udpClient.Conn] {
+			v := cc
+			return blockwise.New(
+				v,
+				s.cfg.BlockwiseTransferTimeout,
+				s.cfg.Errors,
+				func(token message.Token) (*pool.Message, bool) {
+					return v.GetObservationRequest(token)
+				},
+			)
+		}
+	}
+	session := NewSession(
+		s.ctx,
+		connection,
+		s.cfg.MaxMessageSize,
+		s.cfg.MTU,
+		true,
+	)
+	cfg := udpClient.DefaultConfig
+	cfg.TransmissionNStart = s.cfg.TransmissionNStart
+	cfg.TransmissionAcknowledgeTimeout = s.cfg.TransmissionAcknowledgeTimeout
+	cfg.TransmissionMaxRetransmit = s.cfg.TransmissionMaxRetransmit
+	cfg.Handler = s.cfg.Handler
+	cfg.BlockwiseSZX = s.cfg.BlockwiseSZX
+	cfg.Errors = s.cfg.Errors
+	cfg.GetMID = s.cfg.GetMID
+	cfg.GetToken = s.cfg.GetToken
+	cfg.MessagePool = s.cfg.MessagePool
+	cfg.ReceivedMessageQueueSize = s.cfg.ReceivedMessageQueueSize
+	cfg.ProcessReceivedMessage = s.cfg.ProcessReceivedMessage
+	cc := udpClient.NewConn(
+		session,
+		createBlockWise,
+		monitor,
+		&cfg,
+	)
+
+	return cc
+}
diff --git a/vendor/github.com/plgd-dev/go-coap/v3/dtls/server/session.go b/vendor/github.com/plgd-dev/go-coap/v3/dtls/server/session.go
new file mode 100644
index 0000000000..9eb2d001a5
--- /dev/null
+++ b/vendor/github.com/plgd-dev/go-coap/v3/dtls/server/session.go
@@ -0,0 +1,159 @@
+package server
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"net"
+	"sync"
+	"sync/atomic"
+
+	"github.com/plgd-dev/go-coap/v3/message/pool"
+	coapNet "github.com/plgd-dev/go-coap/v3/net"
+	"github.com/plgd-dev/go-coap/v3/udp/client"
+	"github.com/plgd-dev/go-coap/v3/udp/coder"
+)
+
+type EventFunc = func()
+
+type Session struct {
+	onClose []EventFunc
+
+	ctx atomic.Value // TODO: change to atomic.Pointer[context.Context] for go1.19
+
+	cancel     context.CancelFunc
+	connection *coapNet.Conn
+
+	done chan struct{}
+
+	mutex sync.Mutex
+
+	maxMessageSize uint32
+
+	mtu uint16
+
+	closeSocket bool
+}
+
+func NewSession(
+	ctx context.Context,
+	connection *coapNet.Conn,
+	maxMessageSize uint32,
+	mtu uint16,
+	closeSocket bool,
+) *Session {
+	ctx, cancel := context.WithCancel(ctx)
+	s := &Session{
+		cancel:         cancel,
+		connection:     connection,
+		maxMessageSize: maxMessageSize,
+		closeSocket:    closeSocket,
+		mtu:            mtu,
+		done:           make(chan struct{}),
+	}
+	s.ctx.Store(&ctx)
+	return s
+}
+
+// Done signals that the connection is no longer processed.
+func (s *Session) Done() <-chan struct{} {
+	return s.done
+}
+
+func (s *Session) AddOnClose(f EventFunc) {
+	s.mutex.Lock()
+	defer s.mutex.Unlock()
+	s.onClose = append(s.onClose, f)
+}
+
+func (s *Session) popOnClose() []EventFunc {
+	s.mutex.Lock()
+	defer s.mutex.Unlock()
+	tmp := s.onClose
+	s.onClose = nil
+	return tmp
+}
+
+func (s *Session) shutdown() {
+	defer close(s.done)
+	for _, f := range s.popOnClose() {
+		f()
+	}
+}
+
+func (s *Session) Close() error {
+	s.cancel()
+	if s.closeSocket {
+		return s.connection.Close()
+	}
+	return nil
+}
+
+func (s *Session) Context() context.Context {
+	return *s.ctx.Load().(*context.Context) //nolint:forcetypeassert
+}
+
+// SetContextValue stores the value associated with key to the context of the connection.
+func (s *Session) SetContextValue(key interface{}, val interface{}) {
+	ctx := context.WithValue(s.Context(), key, val)
+	s.ctx.Store(&ctx)
+}
+
+func (s *Session) WriteMessage(req *pool.Message) error {
+	data, err := req.MarshalWithEncoder(coder.DefaultCoder)
+	if err != nil {
+		return fmt.Errorf("cannot marshal: %w", err)
+	}
+	err = s.connection.WriteWithContext(req.Context(), data)
+	if err != nil {
+		return fmt.Errorf("cannot write to connection: %w", err)
+	}
+	return err
+}
+
+// WriteMulticastMessage sends a multicast message to the remote multicast address.
+// Currently it is not implemented - it exists only to satisfy the golang udp/client/Session interface.
+func (s *Session) WriteMulticastMessage(*pool.Message, *net.UDPAddr, ...coapNet.MulticastOption) error {
+	return errors.New("multicast messages not implemented for DTLS")
+}
+
+func (s *Session) MaxMessageSize() uint32 {
+	return s.maxMessageSize
+}
+
+func (s *Session) RemoteAddr() net.Addr {
+	return s.connection.RemoteAddr()
+}
+
+func (s *Session) LocalAddr() net.Addr {
+	return s.connection.LocalAddr()
+}
+
+// Run reads and processes requests from a connection until the connection is closed.
+func (s *Session) Run(cc *client.Conn) (err error) {
+	defer func() {
+		err1 := s.Close()
+		if err == nil {
+			err = err1
+		}
+		s.shutdown()
+	}()
+	m := make([]byte, s.mtu)
+	for {
+		readBuf := m
+		readLen, err := s.connection.ReadWithContext(s.Context(), readBuf)
+		if err != nil {
+			return fmt.Errorf("cannot read from connection: %w", err)
+		}
+		readBuf = readBuf[:readLen]
+		err = cc.Process(readBuf)
+		if err != nil {
+			return err
+		}
+	}
+}
+
+// NetConn returns the underlying connection that is wrapped by s. The Conn returned is shared by all invocations of NetConn, so do not modify it.
+func (s *Session) NetConn() net.Conn { + return s.connection.NetConn() +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/message/codes/code_string.go b/vendor/github.com/plgd-dev/go-coap/v3/message/codes/code_string.go new file mode 100644 index 0000000000..79f248a756 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/message/codes/code_string.go @@ -0,0 +1,58 @@ +package codes + +import ( + "fmt" + "strconv" +) + +var codeToString = map[Code]string{ + Empty: "Empty", + GET: "GET", + POST: "POST", + PUT: "PUT", + DELETE: "DELETE", + Created: "Created", + Deleted: "Deleted", + Valid: "Valid", + Changed: "Changed", + Content: "Content", + BadRequest: "BadRequest", + Unauthorized: "Unauthorized", + BadOption: "BadOption", + Forbidden: "Forbidden", + NotFound: "NotFound", + MethodNotAllowed: "MethodNotAllowed", + NotAcceptable: "NotAcceptable", + PreconditionFailed: "PreconditionFailed", + RequestEntityTooLarge: "RequestEntityTooLarge", + UnsupportedMediaType: "UnsupportedMediaType", + TooManyRequests: "TooManyRequests", + InternalServerError: "InternalServerError", + NotImplemented: "NotImplemented", + BadGateway: "BadGateway", + ServiceUnavailable: "ServiceUnavailable", + GatewayTimeout: "GatewayTimeout", + ProxyingNotSupported: "ProxyingNotSupported", + CSM: "Capabilities and Settings Messages", + Ping: "Ping", + Pong: "Pong", + Release: "Release", + Abort: "Abort", +} + +func (c Code) String() string { + val, ok := codeToString[c] + if ok { + return val + } + return "Code(" + strconv.FormatInt(int64(c), 10) + ")" +} + +func ToCode(v string) (Code, error) { + for key, val := range codeToString { + if v == val { + return key, nil + } + } + return 0, fmt.Errorf("not found") +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/message/codes/codes.go b/vendor/github.com/plgd-dev/go-coap/v3/message/codes/codes.go new file mode 100644 index 0000000000..62e3d8c38c --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/message/codes/codes.go @@ -0,0 +1,119 @@ +package codes + +import ( + "fmt" + "strconv" +) + +// A Code is an unsigned 16-bit coap code as defined in the coap spec. 
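+// A code packs a 3-bit class and a 5-bit detail, mirroring HTTP status
+// classes: 2.05 Content, for example, is (2<<5)|5 = 69.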
+type Code uint16 + +// Request Codes +const ( + GET Code = 1 + POST Code = 2 + PUT Code = 3 + DELETE Code = 4 +) + +// Response Codes +const ( + Empty Code = 0 + Created Code = 65 + Deleted Code = 66 + Valid Code = 67 + Changed Code = 68 + Content Code = 69 + Continue Code = 95 + BadRequest Code = 128 + Unauthorized Code = 129 + BadOption Code = 130 + Forbidden Code = 131 + NotFound Code = 132 + MethodNotAllowed Code = 133 + NotAcceptable Code = 134 + RequestEntityIncomplete Code = 136 + PreconditionFailed Code = 140 + RequestEntityTooLarge Code = 141 + UnsupportedMediaType Code = 143 + TooManyRequests Code = 157 + InternalServerError Code = 160 + NotImplemented Code = 161 + BadGateway Code = 162 + ServiceUnavailable Code = 163 + GatewayTimeout Code = 164 + ProxyingNotSupported Code = 165 +) + +// Signaling Codes for TCP +const ( + CSM Code = 225 + Ping Code = 226 + Pong Code = 227 + Release Code = 228 + Abort Code = 229 +) + +const _maxCode = 255 + +var strToCode = map[string]Code{ + `"GET"`: GET, + `"POST"`: POST, + `"PUT"`: PUT, + `"DELETE"`: DELETE, + `"Created"`: Created, + `"Deleted"`: Deleted, + `"Valid"`: Valid, + `"Changed"`: Changed, + `"Content"`: Content, + `"BadRequest"`: BadRequest, + `"Unauthorized"`: Unauthorized, + `"BadOption"`: BadOption, + `"Forbidden"`: Forbidden, + `"NotFound"`: NotFound, + `"MethodNotAllowed"`: MethodNotAllowed, + `"NotAcceptable"`: NotAcceptable, + `"PreconditionFailed"`: PreconditionFailed, + `"RequestEntityTooLarge"`: RequestEntityTooLarge, + `"UnsupportedMediaType"`: UnsupportedMediaType, + `"TooManyRequests"`: TooManyRequests, + `"InternalServerError"`: InternalServerError, + `"NotImplemented"`: NotImplemented, + `"BadGateway"`: BadGateway, + `"ServiceUnavailable"`: ServiceUnavailable, + `"GatewayTimeout"`: GatewayTimeout, + `"ProxyingNotSupported"`: ProxyingNotSupported, + `"Capabilities and Settings Messages"`: CSM, + `"Ping"`: Ping, + `"Pong"`: Pong, + `"Release"`: Release, + `"Abort"`: Abort, +} + +// UnmarshalJSON unmarshals b into the Code. +func (c *Code) UnmarshalJSON(b []byte) error { + // From json.Unmarshaler: By convention, to approximate the behavior of + // Unmarshal itself, Unmarshalers implement UnmarshalJSON([]byte("null")) as + // a no-op. 
+ if string(b) == "null" { + return nil + } + if c == nil { + return fmt.Errorf("nil receiver passed to UnmarshalJSON") + } + + if ci, err := strconv.ParseUint(string(b), 10, 32); err == nil { + if ci >= _maxCode { + return fmt.Errorf("invalid code: %q", ci) + } + + *c = Code(ci) + return nil + } + + if jc, ok := strToCode[string(b)]; ok { + *c = jc + return nil + } + return fmt.Errorf("invalid code: %q", string(b)) +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/message/encodeDecodeUint32.go b/vendor/github.com/plgd-dev/go-coap/v3/message/encodeDecodeUint32.go new file mode 100644 index 0000000000..51355c39b3 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/message/encodeDecodeUint32.go @@ -0,0 +1,48 @@ +package message + +import ( + "encoding/binary" +) + +func EncodeUint32(buf []byte, value uint32) (int, error) { + switch { + case value == 0: + return 0, nil + case value <= max1ByteNumber: + if len(buf) < 1 { + return 1, ErrTooSmall + } + buf[0] = byte(value) + return 1, nil + case value <= max2ByteNumber: + if len(buf) < 2 { + return 2, ErrTooSmall + } + binary.BigEndian.PutUint16(buf, uint16(value)) + return 2, nil + case value <= max3ByteNumber: + if len(buf) < 3 { + return 3, ErrTooSmall + } + rv := make([]byte, 4) + binary.BigEndian.PutUint32(rv, value) + copy(buf, rv[1:]) + return 3, nil + default: + if len(buf) < 4 { + return 4, ErrTooSmall + } + binary.BigEndian.PutUint32(buf, value) + return 4, nil + } +} + +func DecodeUint32(buf []byte) (uint32, int, error) { + if len(buf) > 4 { + buf = buf[:4] + } + tmp := []byte{0, 0, 0, 0} + copy(tmp[4-len(buf):], buf) + value := binary.BigEndian.Uint32(tmp) + return value, len(buf), nil +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/message/error.go b/vendor/github.com/plgd-dev/go-coap/v3/message/error.go new file mode 100644 index 0000000000..267cdb077e --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/message/error.go @@ -0,0 +1,17 @@ +package message + +import "errors" + +var ( + ErrTooSmall = errors.New("too small bytes buffer") + ErrInvalidOptionHeaderExt = errors.New("invalid option header ext") + ErrInvalidTokenLen = errors.New("invalid token length") + ErrInvalidValueLength = errors.New("invalid value length") + ErrShortRead = errors.New("invalid short read") + ErrOptionTruncated = errors.New("option truncated") + ErrOptionUnexpectedExtendMarker = errors.New("option unexpected extend marker") + ErrOptionsTooSmall = errors.New("too small options buffer") + ErrInvalidEncoding = errors.New("invalid encoding") + ErrOptionNotFound = errors.New("option not found") + ErrOptionDuplicate = errors.New("duplicated option") +) diff --git a/vendor/github.com/plgd-dev/go-coap/v3/message/getETag.go b/vendor/github.com/plgd-dev/go-coap/v3/message/getETag.go new file mode 100644 index 0000000000..0b09623034 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/message/getETag.go @@ -0,0 +1,44 @@ +package message + +import ( + "encoding/binary" + "errors" + "hash/crc64" + "io" +) + +// GetETag calculates ETag from payload via CRC64 +func GetETag(r io.ReadSeeker) ([]byte, error) { + if r == nil { + return make([]byte, 8), nil + } + c64 := crc64.New(crc64.MakeTable(crc64.ISO)) + orig, err := r.Seek(0, io.SeekCurrent) + if err != nil { + return nil, err + } + _, err = r.Seek(0, io.SeekStart) + if err != nil { + return nil, err + } + buf := make([]byte, 4096) + for { + bufR := buf + n, errR := r.Read(bufR) + if errors.Is(errR, io.EOF) { + break + } + if errR != nil { + return nil, errR + } + bufR = bufR[:n] + 
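		// Feed only the n bytes actually read into the running CRC64.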
c64.Write(bufR)
+	}
+	_, err = r.Seek(orig, io.SeekStart)
+	if err != nil {
+		return nil, err
+	}
+	b := make([]byte, 8)
+	binary.BigEndian.PutUint64(b, c64.Sum64())
+	return b, nil
+}
diff --git a/vendor/github.com/plgd-dev/go-coap/v3/message/getToken.go b/vendor/github.com/plgd-dev/go-coap/v3/message/getToken.go
new file mode 100644
index 0000000000..4ceedde7db
--- /dev/null
+++ b/vendor/github.com/plgd-dev/go-coap/v3/message/getToken.go
@@ -0,0 +1,29 @@
+package message
+
+import (
+	"crypto/rand"
+	"encoding/hex"
+	"hash/crc64"
+)
+
+type Token []byte
+
+func (t Token) String() string {
+	return hex.EncodeToString(t)
+}
+
+func (t Token) Hash() uint64 {
+	return crc64.Checksum(t, crc64.MakeTable(crc64.ISO))
+}
+
+// GetToken generates a random 8-byte token.
+func GetToken() (Token, error) {
+	b := make(Token, 8)
+	_, err := rand.Read(b)
+	// Note that err == nil only if we read len(b) bytes.
+	if err != nil {
+		return nil, err
+	}
+
+	return b, nil
+}
diff --git a/vendor/github.com/plgd-dev/go-coap/v3/message/getmid.go b/vendor/github.com/plgd-dev/go-coap/v3/message/getmid.go
new file mode 100644
index 0000000000..e049381e4a
--- /dev/null
+++ b/vendor/github.com/plgd-dev/go-coap/v3/message/getmid.go
@@ -0,0 +1,35 @@
+package message
+
+import (
+	"crypto/rand"
+	"encoding/binary"
+	"math"
+	"sync/atomic"
+	"time"
+
+	pkgRand "github.com/plgd-dev/go-coap/v3/pkg/rand"
+)
+
+var weakRng = pkgRand.NewRand(time.Now().UnixNano())
+
+var msgID = uint32(RandMID())
+
+// GetMID generates a message id for UDP. (0 <= mid <= 65535)
+func GetMID() int32 {
+	return int32(uint16(atomic.AddUint32(&msgID, 1)))
+}
+
+func RandMID() int32 {
+	b := make([]byte, 4)
+	_, err := rand.Read(b)
+	if err != nil {
+		// fall back to a cryptographically insecure pseudo-random generator
+		return int32(uint16(weakRng.Uint32() >> 16))
+	}
+	return int32(uint16(binary.BigEndian.Uint32(b)))
+}
+
+// ValidateMID validates a message id for UDP. (0 <= mid <= 65535)
+func ValidateMID(mid int32) bool {
+	return mid >= 0 && mid <= math.MaxUint16
+}
diff --git a/vendor/github.com/plgd-dev/go-coap/v3/message/message.go b/vendor/github.com/plgd-dev/go-coap/v3/message/message.go
new file mode 100644
index 0000000000..c68ce5f36d
--- /dev/null
+++ b/vendor/github.com/plgd-dev/go-coap/v3/message/message.go
@@ -0,0 +1,50 @@
+package message
+
+import (
+	"fmt"
+
+	"github.com/plgd-dev/go-coap/v3/message/codes"
+)
+
+// MaxTokenSize is the maximum token size that can be used in a message.
+const MaxTokenSize = 8
+
+type Message struct {
+	Token   Token
+	Options Options
+	Code    codes.Code
+	Payload []byte
+
+	// For DTLS and UDP messages
+	MessageID int32 // uint16 is valid, all other values are invalid, -1 is used for unset
+	Type      Type  // uint8 is valid, all other values are invalid, -1 is used for unset
+}
+
+func (r *Message) String() string {
+	if r == nil {
+		return "nil"
+	}
+	buf := fmt.Sprintf("Code: %v, Token: %v", r.Code, r.Token)
+	path, err := r.Options.Path()
+	if err == nil {
+		buf = fmt.Sprintf("%s, Path: %v", buf, path)
+	}
+	cf, err := r.Options.ContentFormat()
+	if err == nil {
+		buf = fmt.Sprintf("%s, ContentFormat: %v", buf, cf)
+	}
+	queries, err := r.Options.Queries()
+	if err == nil {
+		buf = fmt.Sprintf("%s, Queries: %+v", buf, queries)
+	}
+	if ValidateType(r.Type) {
+		buf = fmt.Sprintf("%s, Type: %v", buf, r.Type)
+	}
+	if ValidateMID(r.MessageID) {
+		buf = fmt.Sprintf("%s, MessageID: %v", buf, r.MessageID)
+	}
+	if len(r.Payload) > 0 {
+		buf = fmt.Sprintf("%s, PayloadLen: %v", buf, len(r.Payload))
+	}
+	return buf
+}
diff --git a/vendor/github.com/plgd-dev/go-coap/v3/message/noresponse/error.go b/vendor/github.com/plgd-dev/go-coap/v3/message/noresponse/error.go
new file mode 100644
index 0000000000..96d4dc9756
--- /dev/null
+++ b/vendor/github.com/plgd-dev/go-coap/v3/message/noresponse/error.go
@@ -0,0 +1,6 @@
+package noresponse
+
+import "errors"
+
+// ErrMessageNotInterested indicates the message is not of interest to the client.
+var ErrMessageNotInterested = errors.New("message not to be sent due to disinterest")
diff --git a/vendor/github.com/plgd-dev/go-coap/v3/message/noresponse/noresponse.go b/vendor/github.com/plgd-dev/go-coap/v3/message/noresponse/noresponse.go
new file mode 100644
index 0000000000..8bd0e52dec
--- /dev/null
+++ b/vendor/github.com/plgd-dev/go-coap/v3/message/noresponse/noresponse.go
@@ -0,0 +1,52 @@
+package noresponse
+
+import (
+	"github.com/plgd-dev/go-coap/v3/message/codes"
+)
+
+var (
+	resp2XXCodes       = []codes.Code{codes.Created, codes.Deleted, codes.Valid, codes.Changed, codes.Content}
+	resp4XXCodes       = []codes.Code{codes.BadRequest, codes.Unauthorized, codes.BadOption, codes.Forbidden, codes.NotFound, codes.MethodNotAllowed, codes.NotAcceptable, codes.PreconditionFailed, codes.RequestEntityTooLarge, codes.UnsupportedMediaType}
+	resp5XXCodes       = []codes.Code{codes.InternalServerError, codes.NotImplemented, codes.BadGateway, codes.ServiceUnavailable, codes.GatewayTimeout, codes.ProxyingNotSupported}
+	noResponseValueMap = map[uint32][]codes.Code{
+		2:  resp2XXCodes,
+		8:  resp4XXCodes,
+		16: resp5XXCodes,
+	}
+)
+
+func isSet(n uint32, pos uint32) bool {
+	val := n & (1 << pos)
+	return (val > 0)
+}
+
+func decodeNoResponseOption(v uint32) []codes.Code {
+	var codes []codes.Code
+	if v == 0 {
+		// No suppressed codes
+		return codes
+	}
+
+	var i uint32
+	// Max bit value: 4; ref: table 2 of RFC 7967
+	for i = 0; i <= 4; i++ {
+		if isSet(v, i) {
+			index := uint32(1 << i)
+			codes = append(codes,
noResponseValueMap[index]...) + } + } + return codes +} + +// IsNoResponseCode validates response code against NoResponse option from request. +// https://www.rfc-editor.org/rfc/rfc7967.txt +func IsNoResponseCode(code codes.Code, noRespValue uint32) error { + suppressedCodes := decodeNoResponseOption(noRespValue) + + for _, suppressedCode := range suppressedCodes { + if suppressedCode == code { + return ErrMessageNotInterested + } + } + return nil +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/message/option.go b/vendor/github.com/plgd-dev/go-coap/v3/message/option.go new file mode 100644 index 0000000000..1a692f12c9 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/message/option.go @@ -0,0 +1,450 @@ +package message + +import ( + "encoding/binary" + "errors" + "fmt" + "strconv" +) + +const ( + max1ByteNumber = uint32(^uint8(0)) + max2ByteNumber = uint32(^uint16(0)) + max3ByteNumber = uint32(0xffffff) +) + +const ( + ExtendOptionByteCode = 13 + ExtendOptionByteAddend = 13 + ExtendOptionWordCode = 14 + ExtendOptionWordAddend = 269 + ExtendOptionError = 15 +) + +// OptionID identifies an option in a message. +type OptionID uint16 + +/* + +-----+----+---+---+---+----------------+--------+--------+---------+ + | No. | C | U | N | R | Name | Format | Length | Default | + +-----+----+---+---+---+----------------+--------+--------+---------+ + | 1 | x | | | x | If-Match | opaque | 0-8 | (none) | + | 3 | x | x | - | | Uri-Host | string | 1-255 | (see | + | | | | | | | | | below) | + | 4 | | | | x | ETag | opaque | 1-8 | (none) | + | 5 | x | | | | If-None-Match | empty | 0 | (none) | + | 7 | x | x | - | | Uri-Port | uint | 0-2 | (see | + | | | | | | | | | below) | + | 8 | | | | x | Location-Path | string | 0-255 | (none) | + | 11 | x | x | - | x | Uri-Path | string | 0-255 | (none) | + | 12 | | | | | Content-Format | uint | 0-2 | (none) | + | 14 | | x | - | | Max-Age | uint | 0-4 | 60 | + | 15 | x | x | - | x | Uri-Query | string | 0-255 | (none) | + | 17 | x | | | | Accept | uint | 0-2 | (none) | + | 20 | | | | x | Location-Query | string | 0-255 | (none) | + | 23 | x | x | - | - | Block2 | uint | 0-3 | (none) | + | 27 | x | x | - | - | Block1 | uint | 0-3 | (none) | + | 28 | | | x | | Size2 | uint | 0-4 | (none) | + | 35 | x | x | - | | Proxy-Uri | string | 1-1034 | (none) | + | 39 | x | x | - | | Proxy-Scheme | string | 1-255 | (none) | + | 60 | | | x | | Size1 | uint | 0-4 | (none) | + +-----+----+---+---+---+----------------+--------+--------+---------+ + C=Critical, U=Unsafe, N=NoCacheKey, R=Repeatable +*/ + +// Option IDs. 
+const ( + IfMatch OptionID = 1 + URIHost OptionID = 3 + ETag OptionID = 4 + IfNoneMatch OptionID = 5 + Observe OptionID = 6 + URIPort OptionID = 7 + LocationPath OptionID = 8 + URIPath OptionID = 11 + ContentFormat OptionID = 12 + MaxAge OptionID = 14 + URIQuery OptionID = 15 + Accept OptionID = 17 + LocationQuery OptionID = 20 + Block2 OptionID = 23 + Block1 OptionID = 27 + Size2 OptionID = 28 + ProxyURI OptionID = 35 + ProxyScheme OptionID = 39 + Size1 OptionID = 60 + NoResponse OptionID = 258 +) + +var optionIDToString = map[OptionID]string{ + IfMatch: "IfMatch", + URIHost: "URIHost", + ETag: "ETag", + IfNoneMatch: "IfNoneMatch", + Observe: "Observe", + URIPort: "URIPort", + LocationPath: "LocationPath", + URIPath: "URIPath", + ContentFormat: "ContentFormat", + MaxAge: "MaxAge", + URIQuery: "URIQuery", + Accept: "Accept", + LocationQuery: "LocationQuery", + Block2: "Block2", + Block1: "Block1", + Size2: "Size2", + ProxyURI: "ProxyURI", + ProxyScheme: "ProxyScheme", + Size1: "Size1", + NoResponse: "NoResponse", +} + +func (o OptionID) String() string { + str, ok := optionIDToString[o] + if !ok { + return "Option(" + strconv.FormatInt(int64(o), 10) + ")" + } + return str +} + +func ToOptionID(v string) (OptionID, error) { + for key, val := range optionIDToString { + if val == v { + return key, nil + } + } + return 0, fmt.Errorf("not found") +} + +// Option value format (RFC7252 section 3.2) +type ValueFormat uint8 + +const ( + ValueUnknown ValueFormat = iota + ValueEmpty + ValueOpaque + ValueUint + ValueString +) + +type OptionDef struct { + MinLen uint32 + MaxLen uint32 + ValueFormat ValueFormat +} + +var CoapOptionDefs = map[OptionID]OptionDef{ + IfMatch: {ValueFormat: ValueOpaque, MinLen: 0, MaxLen: 8}, + URIHost: {ValueFormat: ValueString, MinLen: 1, MaxLen: 255}, + ETag: {ValueFormat: ValueOpaque, MinLen: 1, MaxLen: 8}, + IfNoneMatch: {ValueFormat: ValueEmpty, MinLen: 0, MaxLen: 0}, + Observe: {ValueFormat: ValueUint, MinLen: 0, MaxLen: 3}, + URIPort: {ValueFormat: ValueUint, MinLen: 0, MaxLen: 2}, + LocationPath: {ValueFormat: ValueString, MinLen: 0, MaxLen: 255}, + URIPath: {ValueFormat: ValueString, MinLen: 0, MaxLen: 255}, + ContentFormat: {ValueFormat: ValueUint, MinLen: 0, MaxLen: 2}, + MaxAge: {ValueFormat: ValueUint, MinLen: 0, MaxLen: 4}, + URIQuery: {ValueFormat: ValueString, MinLen: 0, MaxLen: 255}, + Accept: {ValueFormat: ValueUint, MinLen: 0, MaxLen: 2}, + LocationQuery: {ValueFormat: ValueString, MinLen: 0, MaxLen: 255}, + Block2: {ValueFormat: ValueUint, MinLen: 0, MaxLen: 3}, + Block1: {ValueFormat: ValueUint, MinLen: 0, MaxLen: 3}, + Size2: {ValueFormat: ValueUint, MinLen: 0, MaxLen: 4}, + ProxyURI: {ValueFormat: ValueString, MinLen: 1, MaxLen: 1034}, + ProxyScheme: {ValueFormat: ValueString, MinLen: 1, MaxLen: 255}, + Size1: {ValueFormat: ValueUint, MinLen: 0, MaxLen: 4}, + NoResponse: {ValueFormat: ValueUint, MinLen: 0, MaxLen: 1}, +} + +// MediaType specifies the content format of a message. +type MediaType uint16 + +// Content formats. 
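The CoapOptionDefs table above is what both the codec and the higher-level setters consult when validating option values. A small sketch of checking a value length against those definitions, using the VerifyOptLen helper defined further below in this file (per the RFC 7252 table, an ETag must be 1-8 bytes):

```go
package main

import (
	"fmt"

	"github.com/plgd-dev/go-coap/v3/message"
)

func main() {
	def := message.CoapOptionDefs[message.ETag]
	fmt.Println(def.MinLen, def.MaxLen) // 1 8

	// VerifyOptLen rejects a zero-length ETag and accepts an 8-byte one.
	fmt.Println(message.VerifyOptLen(message.ETag, 0)) // false
	fmt.Println(message.VerifyOptLen(message.ETag, 8)) // true
}
```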
+var ( + TextPlain MediaType // text/plain;charset=utf-8 + AppCoseEncrypt0 MediaType = 16 // application/cose; cose-type="cose-encrypt0" (RFC 8152) + AppCoseMac0 MediaType = 17 // application/cose; cose-type="cose-mac0" (RFC 8152) + AppCoseSign1 MediaType = 18 // application/cose; cose-type="cose-sign1" (RFC 8152) + AppLinkFormat MediaType = 40 // application/link-format + AppXML MediaType = 41 // application/xml + AppOctets MediaType = 42 // application/octet-stream + AppExi MediaType = 47 // application/exi + AppJSON MediaType = 50 // application/json + AppJSONPatch MediaType = 51 // application/json-patch+json (RFC6902) + AppJSONMergePatch MediaType = 52 // application/merge-patch+json (RFC7396) + AppCBOR MediaType = 60 // application/cbor (RFC 7049) + AppCWT MediaType = 61 // application/cwt + AppCoseEncrypt MediaType = 96 // application/cose; cose-type="cose-encrypt" (RFC 8152) + AppCoseMac MediaType = 97 // application/cose; cose-type="cose-mac" (RFC 8152) + AppCoseSign MediaType = 98 // application/cose; cose-type="cose-sign" (RFC 8152) + AppCoseKey MediaType = 101 // application/cose-key (RFC 8152) + AppCoseKeySet MediaType = 102 // application/cose-key-set (RFC 8152) + AppSenmlJSON MediaType = 110 // application/senml+json + AppSenmlCbor MediaType = 112 // application/senml+cbor + AppCoapGroup MediaType = 256 // coap-group+json (RFC 7390) + AppSenmlEtchJSON MediaType = 320 // application/senml-etch+json + AppSenmlEtchCbor MediaType = 322 // application/senml-etch+cbor + AppOcfCbor MediaType = 10000 // application/vnd.ocf+cbor + AppLwm2mTLV MediaType = 11542 // application/vnd.oma.lwm2m+tlv + AppLwm2mJSON MediaType = 11543 // application/vnd.oma.lwm2m+json + AppLwm2mCbor MediaType = 11544 // application/vnd.oma.lwm2m+cbor +) + +var mediaTypeToString = map[MediaType]string{ + TextPlain: "text/plain;charset=utf-8", + AppCoseEncrypt0: "application/cose; cose-type=\"cose-encrypt0\" (RFC 8152)", + AppCoseMac0: "application/cose; cose-type=\"cose-mac0\" (RFC 8152)", + AppCoseSign1: "application/cose; cose-type=\"cose-sign1\" (RFC 8152)", + AppLinkFormat: "application/link-format", + AppXML: "application/xml", + AppOctets: "application/octet-stream", + AppExi: "application/exi", + AppJSON: "application/json", + AppJSONPatch: "application/json-patch+json (RFC6902)", + AppJSONMergePatch: "application/merge-patch+json (RFC7396)", + AppCBOR: "application/cbor (RFC 7049)", + AppCWT: "application/cwt", + AppCoseEncrypt: "application/cose; cose-type=\"cose-encrypt\" (RFC 8152)", + AppCoseMac: "application/cose; cose-type=\"cose-mac\" (RFC 8152)", + AppCoseSign: "application/cose; cose-type=\"cose-sign\" (RFC 8152)", + AppCoseKey: "application/cose-key (RFC 8152)", + AppCoseKeySet: "application/cose-key-set (RFC 8152)", + AppSenmlJSON: "application/senml+json", + AppSenmlCbor: "application/senml+cbor", + AppCoapGroup: "coap-group+json (RFC 7390)", + AppSenmlEtchJSON: "application/senml-etch+json", + AppSenmlEtchCbor: "application/senml-etch+cbor", + AppOcfCbor: "application/vnd.ocf+cbor", + AppLwm2mTLV: "application/vnd.oma.lwm2m+tlv", + AppLwm2mJSON: "application/vnd.oma.lwm2m+json", + AppLwm2mCbor: "application/vnd.oma.lwm2m+cbor", +} + +func (c MediaType) String() string { + str, ok := mediaTypeToString[c] + if !ok { + return "MediaType(" + strconv.FormatInt(int64(c), 10) + ")" + } + return str +} + +func ToMediaType(v string) (MediaType, error) { + for key, val := range mediaTypeToString { + if val == v { + return key, nil + } + } + return 0, fmt.Errorf("not found") +} + +func extendOpt(opt 
int) (int, int) { + ext := 0 + if opt >= ExtendOptionByteAddend { + if opt >= ExtendOptionWordAddend { + ext = opt - ExtendOptionWordAddend + opt = ExtendOptionWordCode + } else { + ext = opt - ExtendOptionByteAddend + opt = ExtendOptionByteCode + } + } + return opt, ext +} + +// VerifyOptLen checks whether valueLen is within (min, max) length limits for given option. +func VerifyOptLen(optID OptionID, valueLen int) bool { + def := CoapOptionDefs[optID] + if valueLen < int(def.MinLen) || valueLen > int(def.MaxLen) { + return false + } + return true +} + +func marshalOptionHeaderExt(buf []byte, opt, ext int) (int, error) { + switch opt { + case ExtendOptionByteCode: + if len(buf) > 0 { + buf[0] = byte(ext) + return 1, nil + } + return 1, ErrTooSmall + case ExtendOptionWordCode: + if len(buf) > 1 { + binary.BigEndian.PutUint16(buf, uint16(ext)) + return 2, nil + } + return 2, ErrTooSmall + } + return 0, nil +} + +func marshalOptionHeader(buf []byte, delta, length int) (int, error) { + size := 0 + + d, dx := extendOpt(delta) + l, lx := extendOpt(length) + + if len(buf) > 0 { + buf[0] = byte(d<<4) | byte(l) + size++ + } else { + buf = nil + size++ + } + var lenBuf int + var err error + if buf == nil { + lenBuf, err = marshalOptionHeaderExt(nil, d, dx) + } else { + lenBuf, err = marshalOptionHeaderExt(buf[size:], d, dx) + } + + switch { + case err == nil: + case errors.Is(err, ErrTooSmall): + buf = nil + default: + return -1, err + } + size += lenBuf + + if buf == nil { + lenBuf, err = marshalOptionHeaderExt(nil, l, lx) + } else { + lenBuf, err = marshalOptionHeaderExt(buf[size:], l, lx) + } + switch { + case err == nil: + case errors.Is(err, ErrTooSmall): + buf = nil + default: + return -1, err + } + size += lenBuf + if buf == nil { + return size, ErrTooSmall + } + return size, nil +} + +type Option struct { + Value []byte + ID OptionID +} + +func (o Option) MarshalValue(buf []byte) (int, error) { + if len(buf) < len(o.Value) { + return len(o.Value), ErrTooSmall + } + copy(buf, o.Value) + return len(o.Value), nil +} + +func (o *Option) UnmarshalValue(buf []byte) (int, error) { + o.Value = buf + return len(buf), nil +} + +func (o Option) Marshal(buf []byte, previousID OptionID) (int, error) { + /* + 0 1 2 3 4 5 6 7 + +---------------+---------------+ + | | | + | Option Delta | Option Length | 1 byte + | | | + +---------------+---------------+ + \ \ + / Option Delta / 0-2 bytes + \ (extended) \ + +-------------------------------+ + \ \ + / Option Length / 0-2 bytes + \ (extended) \ + +-------------------------------+ + \ \ + / / + \ \ + / Option Value / 0 or more bytes + \ \ + / / + \ \ + +-------------------------------+ + */ + delta := int(o.ID) - int(previousID) + + lenBuf, err := o.MarshalValue(nil) + switch { + case err == nil, errors.Is(err, ErrTooSmall): + default: + return -1, err + } + + // header marshal + lenBuf, err = marshalOptionHeader(buf, delta, lenBuf) + switch { + case err == nil: + case errors.Is(err, ErrTooSmall): + buf = nil + default: + return -1, err + } + length := lenBuf + + if buf == nil { + lenBuf, err = o.MarshalValue(nil) + } else { + lenBuf, err = o.MarshalValue(buf[length:]) + } + + switch { + case err == nil: + case errors.Is(err, ErrTooSmall): + buf = nil + default: + return -1, err + } + length += lenBuf + + if buf == nil { + return length, ErrTooSmall + } + return length, nil +} + +func parseExtOpt(data []byte, opt int) (int, int, error) { + processed := 0 + switch opt { + case ExtendOptionByteCode: + if len(data) < 1 { + return 0, -1, ErrOptionTruncated + } + 
opt = int(data[0]) + ExtendOptionByteAddend + processed = 1 + case ExtendOptionWordCode: + if len(data) < 2 { + return 0, -1, ErrOptionTruncated + } + opt = int(binary.BigEndian.Uint16(data[:2])) + ExtendOptionWordAddend + processed = 2 + } + return processed, opt, nil +} + +func (o *Option) Unmarshal(data []byte, optionDefs map[OptionID]OptionDef, optionID OptionID) (int, error) { + if def, ok := optionDefs[optionID]; ok { + if def.ValueFormat == ValueUnknown { + // Skip unrecognized options (RFC7252 section 5.4.1) + return len(data), nil + } + if uint32(len(data)) < def.MinLen || uint32(len(data)) > def.MaxLen { + // Skip options with illegal value length (RFC7252 section 5.4.3) + return len(data), nil + } + } + o.ID = optionID + proc, err := o.UnmarshalValue(data) + if err != nil { + return -1, err + } + return proc, err +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/message/options.go b/vendor/github.com/plgd-dev/go-coap/v3/message/options.go new file mode 100644 index 0000000000..b162f9a7c2 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/message/options.go @@ -0,0 +1,637 @@ +package message + +import ( + "errors" + "strings" +) + +// Options Container of COAP Options, It must be always sort'ed after modification. +type Options []Option + +const maxPathValue = 255 + +// GetPathBufferSize gets the size of the buffer required to store path in URI-Path options. +// +// If the path cannot be stored an error is returned. +func GetPathBufferSize(path string) (int, error) { + size := 0 + for start := 0; start < len(path); { + subPath := path[start:] + segmentSize := strings.Index(subPath, "/") + if segmentSize == 0 { + start++ + continue + } + if segmentSize < 0 { + segmentSize = len(subPath) + } + if segmentSize > maxPathValue { + return -1, ErrInvalidValueLength + } + size += segmentSize + start += segmentSize + 1 + } + return size, nil +} + +func setPath(options Options, optionID OptionID, buf []byte, path string) (Options, int, error) { + if len(path) == 0 { + return options, 0, nil + } + o := options.Remove(optionID) + if path[0] == '/' { + path = path[1:] + } + requiredSize, err := GetPathBufferSize(path) + if err != nil { + return options, -1, err + } + if requiredSize > len(buf) { + return options, -1, ErrTooSmall + } + encoded := 0 + for start := 0; start < len(path); { + subPath := path[start:] + end := strings.Index(subPath, "/") + if end == 0 { + start++ + continue + } + if end < 0 { + end = len(subPath) + } + data := buf[encoded:] + var enc int + var err error + o, enc, err = o.AddString(data, optionID, subPath[:end]) + if err != nil { + return o, -1, err + } + encoded += enc + start += end + 1 + } + return o, encoded, nil +} + +// SetPath splits path by '/' to URIPath options and copies it to buffer. +// +// Returns modified options, number of used buf bytes and error if occurs. +// +// @note the url encoded into URIHost, URIPort, URIPath is expected to be +// absolute (https://www.rfc-editor.org/rfc/rfc7252.txt) +func (options Options) SetPath(buf []byte, path string) (Options, int, error) { + return setPath(options, URIPath, buf, path) +} + +// SetLocationPath splits path by '/' to LocationPath options and copies it to buffer. +// +// Returns modified options, number of used buf bytes and error if occurs. 
+// +// @note the url encoded into LocationPath is expected to be +// absolute (https://www.rfc-editor.org/rfc/rfc7252.txt) +func (options Options) SetLocationPath(buf []byte, path string) (Options, int, error) { + return setPath(options, LocationPath, buf, path) +} + +func (options Options) path(buf []byte, id OptionID) (int, error) { + firstIdx, lastIdx, err := options.Find(id) + if err != nil { + return -1, err + } + var needed int + for i := firstIdx; i < lastIdx; i++ { + needed += len(options[i].Value) + needed++ + } + + if len(buf) < needed { + return needed, ErrTooSmall + } + for i := firstIdx; i < lastIdx; i++ { + buf[0] = '/' + buf = buf[1:] + + copy(buf, options[i].Value) + buf = buf[len(options[i].Value):] + } + return needed, nil +} + +// Path joins URIPath options by '/' and returns the resulting path string. +func (options Options) Path() (string, error) { + buf := make([]byte, 32) + m, err := options.path(buf, URIPath) + if errors.Is(err, ErrTooSmall) { + buf = append(buf, make([]byte, m)...) + m, err = options.path(buf, URIPath) + } + if err != nil { + return "", err + } + buf = buf[:m] + return string(buf), nil +} + +// LocationPath joins Location-Path options by '/' and returns the resulting path string. +func (options Options) LocationPath() (string, error) { + buf := make([]byte, 32) + m, err := options.path(buf, LocationPath) + if errors.Is(err, ErrTooSmall) { + buf = append(buf, make([]byte, m)...) + m, err = options.path(buf, LocationPath) + } + if err != nil { + return "", err + } + buf = buf[:m] + return string(buf), nil +} + +// SetString replaces/stores string option to options. +// +// Returns modified options, number of used buf bytes and error if occurs. +func (options Options) SetString(buf []byte, id OptionID, str string) (Options, int, error) { + data := []byte(str) + return options.SetBytes(buf, id, data) +} + +// AddString appends string option to options. +// +// Returns modified options, number of used buf bytes and error if occurs. +func (options Options) AddString(buf []byte, id OptionID, str string) (Options, int, error) { + data := []byte(str) + return options.AddBytes(buf, id, data) +} + +// HasOption returns true if the option exists in options. +func (options Options) HasOption(id OptionID) bool { + _, _, err := options.Find(id) + return err == nil +} + +// GetUint32s gets all options with same id. +func (options Options) GetUint32s(id OptionID, r []uint32) (int, error) { + firstIdx, lastIdx, err := options.Find(id) + if err != nil { + return 0, err + } + if len(r) < lastIdx-firstIdx { + return lastIdx - firstIdx, ErrTooSmall + } + var idx int + // Find returns a half-open [firstIdx, lastIdx) interval. + for i := firstIdx; i < lastIdx; i++ { + val, _, err := DecodeUint32(options[i].Value) + if err == nil { + r[idx] = val + idx++ + } + } + + return idx, nil +} + +// GetUint32 gets the uint32 value of the first option with the given ID. +func (options Options) GetUint32(id OptionID) (uint32, error) { + firstIdx, _, err := options.Find(id) + if err != nil { + return 0, err + } + val, _, err := DecodeUint32(options[firstIdx].Value) + return val, err +} + +// ContentFormat gets the content format of body. +func (options Options) ContentFormat() (MediaType, error) { + v, err := options.GetUint32(ContentFormat) + return MediaType(v), err +} + +// GetString gets the string value of the first option with the given ID.
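SetPath above reports ErrTooSmall when the caller's scratch buffer cannot hold the encoded segments, and GetPathBufferSize tells the caller how large the buffer has to be. A minimal sketch of that grow-and-retry idiom (the same pattern the pool.Message.SetPath wrapper further below uses); the path string is illustrative:

```go
package main

import (
	"errors"
	"fmt"

	"github.com/plgd-dev/go-coap/v3/message"
)

func main() {
	const path = "/oic/res/some-long-segment"
	opts := make(message.Options, 0, 4)
	buf := make([]byte, 8) // deliberately too small

	opts, _, err := opts.SetPath(buf, path)
	if errors.Is(err, message.ErrTooSmall) {
		need, errSize := message.GetPathBufferSize(path)
		if errSize != nil {
			panic(errSize) // a segment exceeds the 255-byte URI-Path limit
		}
		buf = make([]byte, need)
		opts, _, err = opts.SetPath(buf, path)
	}
	if err != nil {
		panic(err)
	}
	p, _ := opts.Path()
	fmt.Println(p) // "/oic/res/some-long-segment"
}
```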
+func (options Options) GetString(id OptionID) (string, error) { + firstIdx, _, err := options.Find(id) + if err != nil { + return "", err + } + return string(options[firstIdx].Value), nil +} + +// GetStrings gets string array of all options with the given id. +func (options Options) GetStrings(id OptionID, r []string) (int, error) { + firstIdx, lastIdx, err := options.Find(id) + if err != nil { + return 0, err + } + if len(r) < lastIdx-firstIdx { + return lastIdx - firstIdx, ErrTooSmall + } + var idx int + for i := firstIdx; i < lastIdx; i++ { + r[idx] = string(options[i].Value) + idx++ + } + + return idx, nil +} + +// Queries gets URIQuery parameters. +func (options Options) Queries() ([]string, error) { + q := make([]string, 4) + n, err := options.GetStrings(URIQuery, q) + if errors.Is(err, ErrTooSmall) { + q = append(q, make([]string, n-len(q))...) + n, err = options.GetStrings(URIQuery, q) + } + if err != nil { + return nil, err + } + return q[:n], nil +} + +// SetBytes replaces/stores bytes of a option to options. +// +// Returns modified options, number of used buf bytes and error if occurs. +func (options Options) SetBytes(buf []byte, id OptionID, data []byte) (Options, int, error) { + if len(buf) < len(data) { + return options, len(data), ErrTooSmall + } + if id == URIPath && len(data) > maxPathValue { + return options, -1, ErrInvalidValueLength + } + copy(buf, data) + return options.Set(Option{ID: id, Value: buf[:len(data)]}), len(data), nil +} + +// AddBytes appends bytes of a option option to options. +// +// Returns modified options, number of used buf bytes and error if occurs. +func (options Options) AddBytes(buf []byte, id OptionID, data []byte) (Options, int, error) { + if len(buf) < len(data) { + return options, len(data), ErrTooSmall + } + if id == URIPath && len(data) > maxPathValue { + return options, -1, ErrInvalidValueLength + } + copy(buf, data) + return options.Add(Option{ID: id, Value: buf[:len(data)]}), len(data), nil +} + +// GetBytes gets bytes of the first option with given id. +func (options Options) GetBytes(id OptionID) ([]byte, error) { + firstIdx, _, err := options.Find(id) + if err != nil { + return nil, err + } + return options[firstIdx].Value, nil +} + +// GetBytess gets array of bytes of all options with the same id. +func (options Options) GetBytess(id OptionID, r [][]byte) (int, error) { + firstIdx, lastIdx, err := options.Find(id) + if err != nil { + return 0, err + } + if len(r) < lastIdx-firstIdx { + return lastIdx - firstIdx, ErrTooSmall + } + var idx int + for i := firstIdx; i < lastIdx; i++ { + r[idx] = options[i].Value + idx++ + } + + return idx, nil +} + +// AddUint32 appends uint32 option to options. +// +// Returns modified options, number of used buf bytes and error if occurs. +func (options Options) AddUint32(buf []byte, id OptionID, value uint32) (Options, int, error) { + enc, err := EncodeUint32(buf, value) + if err != nil { + return options, enc, err + } + o := options.Add(Option{ID: id, Value: buf[:enc]}) + return o, enc, err +} + +// SetUint32 replaces/stores uint32 option to options. +// +// Returns modified options, number of used buf bytes and error if occurs. +func (options Options) SetUint32(buf []byte, id OptionID, value uint32) (Options, int, error) { + enc, err := EncodeUint32(buf, value) + if err != nil { + return options, enc, err + } + o := options.Set(Option{ID: id, Value: buf[:enc]}) + return o, enc, err +} + +// SetContentFormat sets ContentFormat option. 
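Options such as Uri-Query are repeatable, so AddBytes/AddString append rather than replace, and each successful call consumes a prefix of the caller's buffer. A short sketch of adding two queries and reading them back; note how the buffer is advanced by the number of bytes each call used, since the stored option values alias it:

```go
package main

import (
	"fmt"

	"github.com/plgd-dev/go-coap/v3/message"
)

func main() {
	opts := make(message.Options, 0, 4)
	buf := make([]byte, 64)

	var n int
	var err error
	opts, n, err = opts.AddString(buf, message.URIQuery, "auth=secret")
	if err != nil {
		panic(err)
	}
	buf = buf[n:] // the first n bytes now back the stored option value

	opts, _, err = opts.AddString(buf, message.URIQuery, "x=1")
	if err != nil {
		panic(err)
	}

	qs, err := opts.Queries()
	fmt.Println(qs, err) // [auth=secret x=1] <nil>
}
```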
+func (options Options) SetContentFormat(buf []byte, contentFormat MediaType) (Options, int, error) { + return options.SetUint32(buf, ContentFormat, uint32(contentFormat)) +} + +// SetObserve sets Observe option. +func (options Options) SetObserve(buf []byte, observe uint32) (Options, int, error) { + return options.SetUint32(buf, Observe, observe) +} + +// Observe gets Observe option. +func (options Options) Observe() (uint32, error) { + return options.GetUint32(Observe) +} + +// SetAccept sets accept option. +func (options Options) SetAccept(buf []byte, contentFormat MediaType) (Options, int, error) { + return options.SetUint32(buf, Accept, uint32(contentFormat)) +} + +// Accept gets accept option. +func (options Options) Accept() (MediaType, error) { + v, err := options.GetUint32(Accept) + return MediaType(v), err +} + +// Find returns range of type options. First number is index and second number is index of next option type. +func (options Options) Find(id OptionID) (int, int, error) { + idxPre, idxPost := options.findPosition(id) + if idxPre == -1 && idxPost == 0 { + return -1, -1, ErrOptionNotFound + } + if idxPre == len(options)-1 && idxPost == -1 { + return -1, -1, ErrOptionNotFound + } + if idxPre < idxPost && idxPost-idxPre == 1 { + return -1, -1, ErrOptionNotFound + } + idxPre++ + if idxPost < 0 { + idxPost = len(options) + } + return idxPre, idxPost, nil +} + +// findPosition returns opened interval, -1 at means minIdx insert at 0, -1 maxIdx at maxIdx means append. +func (options Options) findPosition(id OptionID) (minIdx int, maxIdx int) { + if len(options) == 0 { + return -1, 0 + } + pivot := 0 + maxIdx = len(options) + minIdx = 0 + for { + switch { + case id == options[pivot].ID || (maxIdx-minIdx)/2 == 0: + for maxIdx = pivot; maxIdx < len(options) && options[maxIdx].ID <= id; { + maxIdx++ + } + if maxIdx == len(options) { + maxIdx = -1 + } + for minIdx = pivot; minIdx >= 0 && options[minIdx].ID >= id; { + minIdx-- + } + return minIdx, maxIdx + case id < options[pivot].ID: + maxIdx = pivot + pivot = maxIdx - (maxIdx-minIdx)/2 + case id > options[pivot].ID: + minIdx = pivot + pivot = minIdx + (maxIdx-minIdx)/2 + } + } +} + +// Set replaces/stores option at options. +// +// Returns modified options. +func (options Options) Set(opt Option) Options { + idxPre, idxPost := options.findPosition(opt.ID) + if idxPre == -1 && idxPost == -1 { + // append + options = append(options[:0], opt) + return options + } + var insertPosition int + var updateTo int + var updateFrom int + optsLength := len(options) + switch { + case idxPre == -1 && idxPost >= 0: + insertPosition = 0 + updateTo = 1 + updateFrom = idxPost + case idxPre == idxPost: + insertPosition = idxPre + updateFrom = idxPre + updateTo = idxPre + 1 + case idxPre >= 0: + insertPosition = idxPre + 1 + updateTo = idxPre + 2 + updateFrom = idxPost + if updateFrom < 0 { + updateFrom = len(options) + } + if updateTo == updateFrom { + options[insertPosition] = opt + return options + } + } + if len(options) == cap(options) { + options = append(options, Option{}) + } else { + options = options[:len(options)+1] + } + // replace + move + updateIdx := updateTo + if updateFrom < updateTo { + for i := optsLength; i > updateFrom; i-- { + options[i] = options[i-1] + updateIdx++ + } + } else { + for i := updateFrom; i < optsLength; i++ { + options[updateIdx] = options[i] + updateIdx++ + } + } + options[insertPosition] = opt + options = options[:updateIdx] + + return options +} + +// Add appends option to options. 
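Set and Add above differ in how they treat existing options with the same ID: Add keeps the slice sorted and appends another instance, while Set collapses any existing run of that ID down to the single new value (which is why SetETag, in the pool wrapper further below, can guarantee a single remaining ETag). A tiny illustration:

```go
package main

import (
	"fmt"

	"github.com/plgd-dev/go-coap/v3/message"
)

func main() {
	var opts message.Options
	opts = opts.Add(message.Option{ID: message.URIQuery, Value: []byte("a=1")})
	opts = opts.Add(message.Option{ID: message.URIQuery, Value: []byte("b=2")})
	fmt.Println(len(opts)) // 2

	opts = opts.Set(message.Option{ID: message.URIQuery, Value: []byte("c=3")})
	fmt.Println(len(opts)) // 1, only "c=3" remains
}
```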
+func (options Options) Add(opt Option) Options { + _, idxPost := options.findPosition(opt.ID) + if idxPost == -1 { + idxPost = len(options) + } + if len(options) == cap(options) { + options = append(options, Option{}) + } else { + options = options[:len(options)+1] + } + for i := len(options) - 1; i > idxPost; i-- { + options[i] = options[i-1] + } + options[idxPost] = opt + return options +} + +// Remove removes all options with ID. +func (options Options) Remove(id OptionID) Options { + idxPre, idxPost, err := options.Find(id) + if err != nil { + return options + } + updateIdx := idxPre + for i := idxPost; i < len(options); i++ { + options[updateIdx] = options[i] + updateIdx++ + } + length := len(options) - (idxPost - idxPre) + options = options[:length] + + return options +} + +// Marshal marshals options to buf. +// +// Returns the number of used buf bytes. +func (options Options) Marshal(buf []byte) (int, error) { + previousID := OptionID(0) + length := 0 + + for _, o := range options { + // return coap.error but calculate length + if length > len(buf) { + buf = nil + } + + var optionLength int + var err error + + if buf != nil { + optionLength, err = o.Marshal(buf[length:], previousID) + } else { + optionLength, err = o.Marshal(nil, previousID) + } + previousID = o.ID + + switch { + case err == nil: + case errors.Is(err, ErrTooSmall): + buf = nil + default: + return -1, err + } + length += optionLength + } + if buf == nil { + return length, ErrTooSmall + } + return length, nil +} + +// Unmarshal unmarshals data bytes to options and returns the number of consumed bytes. +func (options *Options) Unmarshal(data []byte, optionDefs map[OptionID]OptionDef) (int, error) { + prev := 0 + processed := 0 + for len(data) > 0 { + if data[0] == 0xff { + processed++ + break + } + + delta := int(data[0] >> 4) + length := int(data[0] & 0x0f) + + if delta == ExtendOptionError || length == ExtendOptionError { + return -1, ErrOptionUnexpectedExtendMarker + } + + data = data[1:] + processed++ + + proc, delta, err := parseExtOpt(data, delta) + if err != nil { + return -1, err + } + processed += proc + data = data[proc:] + proc, length, err = parseExtOpt(data, length) + if err != nil { + return -1, err + } + processed += proc + data = data[proc:] + + if len(data) < length { + return -1, ErrOptionTruncated + } + + option := Option{} + oid := OptionID(prev + delta) + proc, err = option.Unmarshal(data[:length], optionDefs, oid) + if err != nil { + return -1, err + } + + if cap(*options) == len(*options) { + return -1, ErrOptionsTooSmall + } + if option.ID != 0 { + (*options) = append(*options, option) + } + + processed += proc + data = data[proc:] + prev = int(oid) + } + + return processed, nil +} + +// ResetOptionsTo resets options to in options. +// +// Returns modified options, number of used buf bytes and error if occurs. +func (options Options) ResetOptionsTo(buf []byte, in Options) (Options, int, error) { + opts := options[:0] + used := 0 + for idx, o := range in { + if len(buf) < len(o.Value) { + for i := idx; i < len(in); i++ { + used += len(in[i].Value) + } + return options, used, ErrTooSmall + } + copy(buf, o.Value) + used += len(o.Value) + opts = opts.Add(Option{ + ID: o.ID, + Value: buf[:len(o.Value)], + }) + buf = buf[len(o.Value):] + } + return opts, used, nil +} + +// Clone create duplicates of options. 
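Marshal and Unmarshal above implement the delta-encoded wire form of the option list. Note that Unmarshal refuses to grow the destination slice (ErrOptionsTooSmall), so the caller must pass Options with spare capacity. A round-trip sketch using only the API shown in this file:

```go
package main

import (
	"fmt"

	"github.com/plgd-dev/go-coap/v3/message"
)

func main() {
	var opts message.Options
	scratch := make([]byte, 8)
	opts, _, err := opts.SetContentFormat(scratch, message.AppJSON)
	if err != nil {
		panic(err)
	}

	wire := make([]byte, 16)
	n, err := opts.Marshal(wire)
	if err != nil {
		panic(err)
	}

	decoded := make(message.Options, 0, 4) // spare capacity is required
	if _, err := decoded.Unmarshal(wire[:n], message.CoapOptionDefs); err != nil {
		panic(err)
	}
	cf, _ := decoded.ContentFormat()
	fmt.Println(cf) // application/json
}
```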
+func (options Options) Clone() (Options, error) { + opts := make(Options, 0, len(options)) + buf := make([]byte, 64) + opts, used, err := opts.ResetOptionsTo(buf, options) + if errors.Is(err, ErrTooSmall) { + buf = append(buf, make([]byte, used-len(buf))...) + opts, _, err = opts.ResetOptionsTo(buf, options) + } + if err != nil { + return nil, err + } + return opts, nil +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/message/pool/message.go b/vendor/github.com/plgd-dev/go-coap/v3/message/pool/message.go new file mode 100644 index 0000000000..22cf886524 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/message/pool/message.go @@ -0,0 +1,600 @@ +package pool + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + + multierror "github.com/hashicorp/go-multierror" + "github.com/plgd-dev/go-coap/v3/message" + "github.com/plgd-dev/go-coap/v3/message/codes" + "go.uber.org/atomic" +) + +type Encoder interface { + Size(m message.Message) (int, error) + Encode(m message.Message, buf []byte) (int, error) +} + +type Decoder interface { + Decode(buf []byte, m *message.Message) (int, error) +} + +type Message struct { + // Context context of request. + ctx context.Context + msg message.Message + hijacked atomic.Bool + isModified bool + valueBuffer []byte + origValueBuffer []byte + body io.ReadSeeker + sequence uint64 + + // local vars + bufferUnmarshal []byte + bufferMarshal []byte +} + +const valueBufferSize = 256 + +func NewMessage(ctx context.Context) *Message { + valueBuffer := make([]byte, valueBufferSize) + return &Message{ + ctx: ctx, + msg: message.Message{ + Options: make(message.Options, 0, 16), + MessageID: -1, + Type: message.Unset, + }, + valueBuffer: valueBuffer, + origValueBuffer: valueBuffer, + bufferUnmarshal: make([]byte, 256), + bufferMarshal: make([]byte, 256), + } +} + +func (r *Message) Context() context.Context { + return r.ctx +} + +func (r *Message) SetContext(ctx context.Context) { + r.ctx = ctx +} + +func (r *Message) SetMessage(message message.Message) { + r.Reset() + r.msg = message + if len(message.Payload) > 0 { + r.body = bytes.NewReader(message.Payload) + } + r.isModified = true +} + +// SetMessageID only 0 to 2^16-1 are valid. +func (r *Message) SetMessageID(mid int32) { + r.msg.MessageID = mid + r.isModified = true +} + +// UpsertMessageID set value only when origin value is invalid. Only 0 to 2^16-1 values are valid. +func (r *Message) UpsertMessageID(mid int32) { + if message.ValidateMID(r.msg.MessageID) { + return + } + r.SetMessageID(mid) +} + +// MessageID returns 0 to 2^16-1 otherwise it contains invalid value. +func (r *Message) MessageID() int32 { + return r.msg.MessageID +} + +func (r *Message) SetType(typ message.Type) { + r.msg.Type = typ + r.isModified = true +} + +// UpsertType set value only when origin value is invalid. Only 0 to 2^8-1 values are valid. 
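The Upsert setters on pool.Message only write when the current value is still the unset sentinel (-1), which lets a transport fill in a MessageID or Type without clobbering one the application already chose. A brief sketch, assuming only the vendored API shown in this file:

```go
package main

import (
	"context"
	"fmt"

	"github.com/plgd-dev/go-coap/v3/message/codes"
	"github.com/plgd-dev/go-coap/v3/message/pool"
)

func main() {
	m := pool.NewMessage(context.Background())
	m.SetCode(codes.GET)

	m.UpsertMessageID(1234)    // MessageID starts at -1 (unset), so this sets it
	m.UpsertMessageID(9999)    // no-op: a valid MID is already present
	fmt.Println(m.MessageID()) // 1234
}
```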
+func (r *Message) UpsertType(typ message.Type) { + if message.ValidateType(r.msg.Type) { + return + } + r.SetType(typ) +} + +func (r *Message) Type() message.Type { + return r.msg.Type +} + +// Reset clear message for next reuse +func (r *Message) Reset() { + r.msg.Token = nil + r.msg.Code = codes.Empty + r.msg.Options = r.msg.Options[:0] + r.msg.MessageID = -1 + r.msg.Type = message.Unset + r.msg.Payload = nil + r.valueBuffer = r.origValueBuffer + r.body = nil + r.isModified = false + if cap(r.bufferMarshal) > 1024 { + r.bufferMarshal = make([]byte, 256) + } + if cap(r.bufferUnmarshal) > 1024 { + r.bufferUnmarshal = make([]byte, 256) + } + r.isModified = false +} + +func (r *Message) Path() (string, error) { + return r.msg.Options.Path() +} + +func (r *Message) Queries() ([]string, error) { + return r.msg.Options.Queries() +} + +func (r *Message) Remove(opt message.OptionID) { + r.msg.Options = r.msg.Options.Remove(opt) + r.isModified = true +} + +func (r *Message) Token() message.Token { + if r.msg.Token == nil { + return nil + } + token := make(message.Token, 0, 8) + token = append(token, r.msg.Token...) + return token +} + +func (r *Message) SetToken(token message.Token) { + if token == nil { + r.msg.Token = nil + return + } + r.msg.Token = append(r.msg.Token[:0], token...) +} + +func (r *Message) ResetOptionsTo(in message.Options) { + opts, used, err := r.msg.Options.ResetOptionsTo(r.valueBuffer, in) + if errors.Is(err, message.ErrTooSmall) { + r.valueBuffer = append(r.valueBuffer, make([]byte, used)...) + opts, used, err = r.msg.Options.ResetOptionsTo(r.valueBuffer, in) + } + if err != nil { + panic(fmt.Errorf("cannot reset options to: %w", err)) + } + r.msg.Options = opts + r.valueBuffer = r.valueBuffer[used:] + if len(in) > 0 { + r.isModified = true + } +} + +func (r *Message) Options() message.Options { + return r.msg.Options +} + +// SetPath stores the given path within URI-Path options. +// +// The value is stored by the algorithm described in RFC7252 and +// using the internal buffer. If the path is too long, but valid +// (URI-Path segments must have maximal length of 255) the internal +// buffer is expanded. +// If the path is too long, but not valid then the function returns +// ErrInvalidValueLength error. +func (r *Message) SetPath(p string) error { + opts, used, err := r.msg.Options.SetPath(r.valueBuffer, p) + if errors.Is(err, message.ErrTooSmall) { + expandBy, errSize := message.GetPathBufferSize(p) + if errSize != nil { + return fmt.Errorf("cannot calculate buffer size for path: %w", errSize) + } + r.valueBuffer = append(r.valueBuffer, make([]byte, expandBy)...) + opts, used, err = r.msg.Options.SetPath(r.valueBuffer, p) + } + if err != nil { + return fmt.Errorf("cannot set path: %w", err) + } + r.msg.Options = opts + r.valueBuffer = r.valueBuffer[used:] + r.isModified = true + return nil +} + +// MustSetPath calls SetPath and panics if it returns an error. +func (r *Message) MustSetPath(p string) { + if err := r.SetPath(p); err != nil { + panic(err) + } +} + +func (r *Message) Code() codes.Code { + return r.msg.Code +} + +func (r *Message) SetCode(code codes.Code) { + r.msg.Code = code + r.isModified = true +} + +// AddETag appends value to existing ETags. 
+// +// Option definition: +// - format: opaque, length: 1-8, repeatable +func (r *Message) AddETag(value []byte) error { + if !message.VerifyOptLen(message.ETag, len(value)) { + return message.ErrInvalidValueLength + } + r.AddOptionBytes(message.ETag, value) + return nil +} + +// SetETag inserts/replaces ETag option(s). +// +// After a successful call only a single ETag value will remain. +func (r *Message) SetETag(value []byte) error { + if !message.VerifyOptLen(message.ETag, len(value)) { + return message.ErrInvalidValueLength + } + r.SetOptionBytes(message.ETag, value) + return nil +} + +// ETag returns first ETag value +func (r *Message) ETag() ([]byte, error) { + return r.GetOptionBytes(message.ETag) +} + +// ETags returns all ETag values +// +// Writes ETag values to output array, returns number of written values or error. +func (r *Message) ETags(b [][]byte) (int, error) { + return r.GetOptionAllBytes(message.ETag, b) +} + +func (r *Message) AddQuery(query string) { + r.AddOptionString(message.URIQuery, query) +} + +func (r *Message) GetOptionUint32(id message.OptionID) (uint32, error) { + return r.msg.Options.GetUint32(id) +} + +func (r *Message) SetOptionString(opt message.OptionID, value string) { + opts, used, err := r.msg.Options.SetString(r.valueBuffer, opt, value) + if errors.Is(err, message.ErrTooSmall) { + r.valueBuffer = append(r.valueBuffer, make([]byte, used)...) + opts, used, err = r.msg.Options.SetString(r.valueBuffer, opt, value) + } + if err != nil { + panic(fmt.Errorf("cannot set string option: %w", err)) + } + r.msg.Options = opts + r.valueBuffer = r.valueBuffer[used:] + r.isModified = true +} + +func (r *Message) AddOptionString(opt message.OptionID, value string) { + opts, used, err := r.msg.Options.AddString(r.valueBuffer, opt, value) + if errors.Is(err, message.ErrTooSmall) { + r.valueBuffer = append(r.valueBuffer, make([]byte, used)...) + opts, used, err = r.msg.Options.AddString(r.valueBuffer, opt, value) + } + if err != nil { + panic(fmt.Errorf("cannot add string option: %w", err)) + } + r.msg.Options = opts + r.valueBuffer = r.valueBuffer[used:] + r.isModified = true +} + +func (r *Message) AddOptionBytes(opt message.OptionID, value []byte) { + if len(r.valueBuffer) < len(value) { + r.valueBuffer = append(r.valueBuffer, make([]byte, len(value)-len(r.valueBuffer))...) + } + n := copy(r.valueBuffer, value) + v := r.valueBuffer[:n] + r.msg.Options = r.msg.Options.Add(message.Option{ID: opt, Value: v}) + r.valueBuffer = r.valueBuffer[n:] + r.isModified = true +} + +func (r *Message) SetOptionBytes(opt message.OptionID, value []byte) { + if len(r.valueBuffer) < len(value) { + r.valueBuffer = append(r.valueBuffer, make([]byte, len(value)-len(r.valueBuffer))...) + } + n := copy(r.valueBuffer, value) + v := r.valueBuffer[:n] + r.msg.Options = r.msg.Options.Set(message.Option{ID: opt, Value: v}) + r.valueBuffer = r.valueBuffer[n:] + r.isModified = true +} + +// GetOptionBytes gets bytes of the first option with given ID. +func (r *Message) GetOptionBytes(id message.OptionID) ([]byte, error) { + return r.msg.Options.GetBytes(id) +} + +// GetOptionAllBytes gets array of bytes of all options with given ID. 
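AddETag/SetETag above validate the RFC 7252 length constraint (1-8 bytes) via VerifyOptLen before touching the option list, and ETags writes all stored values into a caller-supplied slice. A short sketch:

```go
package main

import (
	"context"
	"errors"
	"fmt"

	"github.com/plgd-dev/go-coap/v3/message"
	"github.com/plgd-dev/go-coap/v3/message/pool"
)

func main() {
	m := pool.NewMessage(context.Background())
	if err := m.AddETag([]byte{0x01}); err != nil {
		panic(err)
	}
	if err := m.AddETag([]byte{0x02, 0x03}); err != nil {
		panic(err)
	}

	b := make([][]byte, 4)
	n, err := m.ETags(b)
	fmt.Println(n, err) // 2 <nil>

	// A zero-length ETag violates the option definition above.
	err = m.AddETag(nil)
	fmt.Println(errors.Is(err, message.ErrInvalidValueLength)) // true
}
```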
+func (r *Message) GetOptionAllBytes(id message.OptionID, b [][]byte) (int, error) { + return r.msg.Options.GetBytess(id, b) +} + +func (r *Message) SetOptionUint32(opt message.OptionID, value uint32) { + opts, used, err := r.msg.Options.SetUint32(r.valueBuffer, opt, value) + if errors.Is(err, message.ErrTooSmall) { + r.valueBuffer = append(r.valueBuffer, make([]byte, used)...) + opts, used, err = r.msg.Options.SetUint32(r.valueBuffer, opt, value) + } + if err != nil { + panic(fmt.Errorf("cannot set uint32 option: %w", err)) + } + r.msg.Options = opts + r.valueBuffer = r.valueBuffer[used:] + r.isModified = true +} + +func (r *Message) AddOptionUint32(opt message.OptionID, value uint32) { + opts, used, err := r.msg.Options.AddUint32(r.valueBuffer, opt, value) + if errors.Is(err, message.ErrTooSmall) { + r.valueBuffer = append(r.valueBuffer, make([]byte, used)...) + opts, used, err = r.msg.Options.AddUint32(r.valueBuffer, opt, value) + } + if err != nil { + panic(fmt.Errorf("cannot add uint32 option: %w", err)) + } + r.msg.Options = opts + r.valueBuffer = r.valueBuffer[used:] + r.isModified = true +} + +func (r *Message) ContentFormat() (message.MediaType, error) { + v, err := r.GetOptionUint32(message.ContentFormat) + return message.MediaType(v), err +} + +func (r *Message) HasOption(id message.OptionID) bool { + return r.msg.Options.HasOption(id) +} + +func (r *Message) SetContentFormat(contentFormat message.MediaType) { + r.SetOptionUint32(message.ContentFormat, uint32(contentFormat)) +} + +func (r *Message) SetObserve(observe uint32) { + r.SetOptionUint32(message.Observe, observe) +} + +func (r *Message) Observe() (uint32, error) { + return r.GetOptionUint32(message.Observe) +} + +// SetAccept sets the accept option. +func (r *Message) SetAccept(contentFormat message.MediaType) { + r.SetOptionUint32(message.Accept, uint32(contentFormat)) +} + +// Accept gets the accept option.
+func (r *Message) Accept() (message.MediaType, error) { + v, err := r.GetOptionUint32(message.Accept) + return message.MediaType(v), err +} + +func (r *Message) BodySize() (int64, error) { + if r.body == nil { + return 0, nil + } + orig, err := r.body.Seek(0, io.SeekCurrent) + if err != nil { + return 0, err + } + _, err = r.body.Seek(0, io.SeekStart) + if err != nil { + return 0, err + } + size, err := r.body.Seek(0, io.SeekEnd) + if err != nil { + return 0, err + } + _, err = r.body.Seek(orig, io.SeekStart) + if err != nil { + return 0, err + } + return size, nil +} + +func (r *Message) SetBody(s io.ReadSeeker) { + r.body = s + r.isModified = true +} + +func (r *Message) Body() io.ReadSeeker { + return r.body +} + +func (r *Message) SetSequence(seq uint64) { + r.sequence = seq +} + +func (r *Message) Sequence() uint64 { + return r.sequence +} + +func (r *Message) Hijack() { + r.hijacked.Store(true) +} + +func (r *Message) IsHijacked() bool { + return r.hijacked.Load() +} + +func (r *Message) IsModified() bool { + return r.isModified +} + +func (r *Message) SetModified(b bool) { + r.isModified = b +} + +func (r *Message) String() string { + return r.msg.String() +} + +func (r *Message) ReadBody() ([]byte, error) { + if r.Body() == nil { + return nil, nil + } + size, err := r.BodySize() + if err != nil { + return nil, err + } + if size == 0 { + return nil, nil + } + _, err = r.Body().Seek(0, io.SeekStart) + if err != nil { + return nil, err + } + payload := make([]byte, 1024) + if int64(len(payload)) < size { + payload = make([]byte, size) + } + n, err := io.ReadFull(r.Body(), payload) + if (errors.Is(err, io.ErrUnexpectedEOF) || errors.Is(err, io.EOF)) && int64(n) == size { + err = nil + } + if err != nil { + return nil, err + } + return payload[:n], nil +} + +func (r *Message) toMessage() (message.Message, error) { + payload, err := r.ReadBody() + if err != nil { + return message.Message{}, err + } + m := r.msg + m.Payload = payload + return m, nil +} + +func (r *Message) MarshalWithEncoder(encoder Encoder) ([]byte, error) { + msg, err := r.toMessage() + if err != nil { + return nil, err + } + size, err := encoder.Size(msg) + if err != nil { + return nil, err + } + if len(r.bufferMarshal) < size { + r.bufferMarshal = append(r.bufferMarshal, make([]byte, size-len(r.bufferMarshal))...) + } + n, err := encoder.Encode(msg, r.bufferMarshal) + if err != nil { + return nil, err + } + r.bufferMarshal = r.bufferMarshal[:n] + return r.bufferMarshal, nil +} + +func (r *Message) UnmarshalWithDecoder(decoder Decoder, data []byte) (int, error) { + if len(r.bufferUnmarshal) < len(data) { + r.bufferUnmarshal = append(r.bufferUnmarshal, make([]byte, len(data)-len(r.bufferUnmarshal))...) 
+ } + copy(r.bufferUnmarshal, data) + r.body = nil + r.bufferUnmarshal = r.bufferUnmarshal[:len(data)] + n, err := decoder.Decode(r.bufferUnmarshal, &r.msg) + if err != nil { + return n, err + } + if len(r.msg.Payload) > 0 { + r.body = bytes.NewReader(r.msg.Payload) + } + return n, err +} + +func (r *Message) IsSeparateMessage() bool { + return r.Code() == codes.Empty && r.Token() == nil && r.Type() == message.Acknowledgement && len(r.Options()) == 0 && r.Body() == nil +} + +func (r *Message) setupCommon(code codes.Code, path string, token message.Token, opts ...message.Option) error { + r.SetCode(code) + r.SetToken(token) + r.ResetOptionsTo(opts) + return r.SetPath(path) +} + +func (r *Message) SetupGet(path string, token message.Token, opts ...message.Option) error { + return r.setupCommon(codes.GET, path, token, opts...) +} + +func (r *Message) SetupPost(path string, token message.Token, contentFormat message.MediaType, payload io.ReadSeeker, opts ...message.Option) error { + if err := r.setupCommon(codes.POST, path, token, opts...); err != nil { + return err + } + if payload != nil { + r.SetContentFormat(contentFormat) + r.SetBody(payload) + } + return nil +} + +func (r *Message) SetupPut(path string, token message.Token, contentFormat message.MediaType, payload io.ReadSeeker, opts ...message.Option) error { + if err := r.setupCommon(codes.PUT, path, token, opts...); err != nil { + return err + } + if payload != nil { + r.SetContentFormat(contentFormat) + r.SetBody(payload) + } + return nil +} + +func (r *Message) SetupDelete(path string, token message.Token, opts ...message.Option) error { + return r.setupCommon(codes.DELETE, path, token, opts...) +} + +func (r *Message) Clone(msg *Message) error { + msg.SetCode(r.Code()) + msg.SetToken(r.Token()) + msg.ResetOptionsTo(r.Options()) + msg.SetType(r.Type()) + msg.SetMessageID(r.MessageID()) + + if r.Body() != nil { + buf := bytes.NewBuffer(nil) + n, err := r.Body().Seek(0, io.SeekCurrent) + if err != nil { + return err + } + _, err = r.body.Seek(0, io.SeekStart) + if err != nil { + return err + } + _, err = io.Copy(buf, r.Body()) + if err != nil { + var errs *multierror.Error + errs = multierror.Append(errs, err) + _, errS := r.Body().Seek(n, io.SeekStart) + if errS != nil { + errs = multierror.Append(errs, errS) + } + return errs.ErrorOrNil() + } + _, err = r.Body().Seek(n, io.SeekStart) + if err != nil { + return err + } + r := bytes.NewReader(buf.Bytes()) + msg.SetBody(r) + } + return nil +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/message/pool/pool.go b/vendor/github.com/plgd-dev/go-coap/v3/message/pool/pool.go new file mode 100644 index 0000000000..e9c91a2921 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/message/pool/pool.go @@ -0,0 +1,64 @@ +package pool + +import ( + "context" + "fmt" + "sync" + + "go.uber.org/atomic" +) + +type Pool struct { + // This field needs to be the first in the struct to ensure proper word alignment on 32-bit platforms. + // See: https://golang.org/pkg/sync/atomic/#pkg-note-BUG + currentMessagesInPool atomic.Int64 + messagePool sync.Pool + maxNumMessages uint32 + maxMessageBufferSize uint16 +} + +func New(maxNumMessages uint32, maxMessageBufferSize uint16) *Pool { + return &Pool{ + maxNumMessages: maxNumMessages, + maxMessageBufferSize: maxMessageBufferSize, + } +} + +// AcquireMessage returns an empty Message instance from Message pool. +// +// The returned Message instance may be passed to ReleaseMessage when it is +// no longer needed. 
This allows Message recycling, reduces GC pressure +// and usually improves performance. +func (p *Pool) AcquireMessage(ctx context.Context) *Message { + v := p.messagePool.Get() + if v == nil { + return NewMessage(ctx) + } + r, ok := v.(*Message) + if !ok { + panic(fmt.Errorf("invalid message type(%T) for pool", v)) + } + p.currentMessagesInPool.Dec() + r.ctx = ctx + return r +} + +// ReleaseMessage returns req acquired via AcquireMessage to Message pool. +// +// It is forbidden accessing req and/or its' members after returning +// it to Message pool. +func (p *Pool) ReleaseMessage(req *Message) { + for { + v := p.currentMessagesInPool.Load() + if v >= int64(p.maxNumMessages) { + return + } + next := v + 1 + if p.currentMessagesInPool.CompareAndSwap(v, next) { + break + } + } + req.Reset() + req.ctx = nil + p.messagePool.Put(req) +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/message/tcpOptions.go b/vendor/github.com/plgd-dev/go-coap/v3/message/tcpOptions.go new file mode 100644 index 0000000000..b6d6a740aa --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/message/tcpOptions.go @@ -0,0 +1,78 @@ +package message + +// Signal CSM Option IDs +/* + +-----+---+---+-------------------+--------+--------+---------+ + | No. | C | R | Name | Format | Length | Default | + +-----+---+---+-------------------+--------+--------+---------+ + | 2 | | | MaxMessageSize | uint | 0-4 | 1152 | + | 4 | | | BlockWiseTransfer | empty | 0 | (none) | + +-----+---+---+-------------------+--------+--------+---------+ + C=Critical, R=Repeatable +*/ + +const ( + TCPMaxMessageSize OptionID = 2 + TCPBlockWiseTransfer OptionID = 4 +) + +// Signal Ping/Pong Option IDs +/* + +-----+---+---+-------------------+--------+--------+---------+ + | No. | C | R | Name | Format | Length | Default | + +-----+---+---+-------------------+--------+--------+---------+ + | 2 | | | Custody | empty | 0 | (none) | + +-----+---+---+-------------------+--------+--------+---------+ + C=Critical, R=Repeatable +*/ + +const ( + TCPCustody OptionID = 2 +) + +// Signal Release Option IDs +/* + +-----+---+---+---------------------+--------+--------+---------+ + | No. | C | R | Name | Format | Length | Default | + +-----+---+---+---------------------+--------+--------+---------+ + | 2 | | x | Alternative-Address | string | 1-255 | (none) | + | 4 | | | Hold-Off | uint3 | 0-3 | (none) | + +-----+---+---+---------------------+--------+--------+---------+ + C=Critical, R=Repeatable +*/ + +const ( + TCPAlternativeAddress OptionID = 2 + TCPHoldOff OptionID = 4 +) + +// Signal Abort Option IDs +/* + +-----+---+---+---------------------+--------+--------+---------+ + | No. 
| C | R | Name | Format | Length | Default | + +-----+---+---+---------------------+--------+--------+---------+ + | 2 | | | Bad-CSM-Option | uint | 0-2 | (none) | + +-----+---+---+---------------------+--------+--------+---------+ + C=Critical, R=Repeatable +*/ +const ( + TCPBadCSMOption OptionID = 2 +) + +var TCPSignalCSMOptionDefs = map[OptionID]OptionDef{ + TCPMaxMessageSize: {ValueFormat: ValueUint, MinLen: 0, MaxLen: 4}, + TCPBlockWiseTransfer: {ValueFormat: ValueEmpty, MinLen: 0, MaxLen: 0}, +} + +var TCPSignalPingPongOptionDefs = map[OptionID]OptionDef{ + TCPCustody: {ValueFormat: ValueEmpty, MinLen: 0, MaxLen: 0}, +} + +var TCPSignalReleaseOptionDefs = map[OptionID]OptionDef{ + TCPAlternativeAddress: {ValueFormat: ValueString, MinLen: 1, MaxLen: 255}, + TCPHoldOff: {ValueFormat: ValueUint, MinLen: 0, MaxLen: 3}, +} + +var TCPSignalAbortOptionDefs = map[OptionID]OptionDef{ + TCPBadCSMOption: {ValueFormat: ValueUint, MinLen: 0, MaxLen: 2}, +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/message/type.go b/vendor/github.com/plgd-dev/go-coap/v3/message/type.go new file mode 100644 index 0000000000..1c0e0ea505 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/message/type.go @@ -0,0 +1,45 @@ +package message + +import ( + "math" + "strconv" +) + +// Type represents the message type. +// It's only part of CoAP UDP messages. +// Reliable transports like TCP do not have a type. +type Type int16 + +const ( + // Used for unset + Unset Type = -1 + // Confirmable messages require acknowledgements. + Confirmable Type = 0 + // NonConfirmable messages do not require acknowledgements. + NonConfirmable Type = 1 + // Acknowledgement is a message indicating a response to confirmable message. + Acknowledgement Type = 2 + // Reset indicates a permanent negative acknowledgement. + Reset Type = 3 +) + +var typeToString = map[Type]string{ + Unset: "Unset", + Confirmable: "Confirmable", + NonConfirmable: "NonConfirmable", + Acknowledgement: "Acknowledgement", + Reset: "Reset", +} + +func (t Type) String() string { + val, ok := typeToString[t] + if ok { + return val + } + return "Type(" + strconv.FormatInt(int64(t), 10) + ")" +} + +// ValidateType validates the type for UDP. 
(0 <= typ <= 255) +func ValidateType(typ Type) bool { + return typ >= 0 && typ <= math.MaxUint8 +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/mux/client.go b/vendor/github.com/plgd-dev/go-coap/v3/mux/client.go new file mode 100644 index 0000000000..2a82077e51 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/mux/client.go @@ -0,0 +1,51 @@ +package mux + +import ( + "context" + "io" + "net" + + "github.com/plgd-dev/go-coap/v3/message" + "github.com/plgd-dev/go-coap/v3/message/pool" +) + +type Observation = interface { + Cancel(ctx context.Context, opts ...message.Option) error + Canceled() bool +} + +type Conn interface { + // create message from pool + AcquireMessage(ctx context.Context) *pool.Message + // return back the message to the pool for next use + ReleaseMessage(m *pool.Message) + + Ping(ctx context.Context) error + Get(ctx context.Context, path string, opts ...message.Option) (*pool.Message, error) + Delete(ctx context.Context, path string, opts ...message.Option) (*pool.Message, error) + Post(ctx context.Context, path string, contentFormat message.MediaType, payload io.ReadSeeker, opts ...message.Option) (*pool.Message, error) + Put(ctx context.Context, path string, contentFormat message.MediaType, payload io.ReadSeeker, opts ...message.Option) (*pool.Message, error) + Observe(ctx context.Context, path string, observeFunc func(notification *pool.Message), opts ...message.Option) (Observation, error) + + RemoteAddr() net.Addr + // NetConn returns the underlying connection that is wrapped by client. The Conn returned is shared by all invocations of NetConn, so do not modify it. + NetConn() net.Conn + Context() context.Context + SetContextValue(key interface{}, val interface{}) + WriteMessage(req *pool.Message) error + // used for GET,PUT,POST,DELETE + Do(req *pool.Message) (*pool.Message, error) + // used for observation (GET with observe 0) + DoObserve(req *pool.Message, observeFunc func(req *pool.Message)) (Observation, error) + Close() error + Sequence() uint64 + // Done signalizes that connection is not more processed. + Done() <-chan struct{} + AddOnClose(func()) + + NewGetRequest(ctx context.Context, path string, opts ...message.Option) (*pool.Message, error) + NewObserveRequest(ctx context.Context, path string, opts ...message.Option) (*pool.Message, error) + NewPutRequest(ctx context.Context, path string, contentFormat message.MediaType, payload io.ReadSeeker, opts ...message.Option) (*pool.Message, error) + NewPostRequest(ctx context.Context, path string, contentFormat message.MediaType, payload io.ReadSeeker, opts ...message.Option) (*pool.Message, error) + NewDeleteRequest(ctx context.Context, path string, opts ...message.Option) (*pool.Message, error) +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/mux/message.go b/vendor/github.com/plgd-dev/go-coap/v3/mux/message.go new file mode 100644 index 0000000000..bba3c2cfd4 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/mux/message.go @@ -0,0 +1,16 @@ +package mux + +import "github.com/plgd-dev/go-coap/v3/message/pool" + +// RouteParams contains all the information related to a route +type RouteParams struct { + Path string + Vars map[string]string + PathTemplate string +} + +// Message contains message with sequence number. 
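The Conn interface above is what mux handlers and clients program against; responses come from the connection's message pool and should be handed back via ReleaseMessage once consumed. A hedged sketch of a helper built on that interface (the package and fetchPayload name are illustrative, not part of the library):

```go
package coapclient

import (
	"context"
	"fmt"

	"github.com/plgd-dev/go-coap/v3/message/codes"
	"github.com/plgd-dev/go-coap/v3/mux"
)

// fetchPayload issues a GET over an established CoAP connection and
// returns the response body.
func fetchPayload(ctx context.Context, c mux.Conn, path string) ([]byte, error) {
	resp, err := c.Get(ctx, path)
	if err != nil {
		return nil, err
	}
	defer c.ReleaseMessage(resp) // hand the pooled message back

	if resp.Code() != codes.Content {
		return nil, fmt.Errorf("unexpected response code %v", resp.Code())
	}
	// ReadBody copies the payload, so it stays valid after release.
	return resp.ReadBody()
}
```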
+type Message struct { + *pool.Message + RouteParams *RouteParams +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/mux/middleware.go b/vendor/github.com/plgd-dev/go-coap/v3/mux/middleware.go new file mode 100644 index 0000000000..94ad7ba928 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/mux/middleware.go @@ -0,0 +1,16 @@ +package mux + +// MiddlewareFunc is a function which receives an Handler and returns another Handler. +// Typically, the returned handler is a closure which does something with the ResponseWriter and Message passed +// to it, and then calls the handler passed as parameter to the MiddlewareFunc. +type MiddlewareFunc func(Handler) Handler + +// Middleware allows MiddlewareFunc to implement the middleware interface. +func (mw MiddlewareFunc) Middleware(handler Handler) Handler { + return mw(handler) +} + +// Use appends a MiddlewareFunc to the chain. Middleware can be used to intercept or otherwise modify requests and/or responses, and are executed in the order that they are applied to the Router. +func (r *Router) Use(mwf ...MiddlewareFunc) { + r.middlewares = append(r.middlewares, mwf...) +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/mux/muxResponseWriter.go b/vendor/github.com/plgd-dev/go-coap/v3/mux/muxResponseWriter.go new file mode 100644 index 0000000000..32ea9d4f7c --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/mux/muxResponseWriter.go @@ -0,0 +1,47 @@ +package mux + +import ( + "io" + + "github.com/plgd-dev/go-coap/v3/message" + "github.com/plgd-dev/go-coap/v3/message/codes" + "github.com/plgd-dev/go-coap/v3/message/pool" + "github.com/plgd-dev/go-coap/v3/net/responsewriter" +) + +// ToHandler converts mux handler to udp/dtls/tcp handler. +func ToHandler[C Conn](m Handler) func(w *responsewriter.ResponseWriter[C], r *pool.Message) { + return func(w *responsewriter.ResponseWriter[C], r *pool.Message) { + muxw := &muxResponseWriter[C]{ + w: w, + } + m.ServeCOAP(muxw, &Message{ + Message: r, + RouteParams: new(RouteParams), + }) + } +} + +type muxResponseWriter[C Conn] struct { + w *responsewriter.ResponseWriter[C] +} + +// SetResponse simplifies the setup of the response for the request. ETags must be set via options. For advanced setup, use Message(). +func (w *muxResponseWriter[C]) SetResponse(code codes.Code, contentFormat message.MediaType, d io.ReadSeeker, opts ...message.Option) error { + return w.w.SetResponse(code, contentFormat, d, opts...) +} + +// Conn peer connection. +func (w *muxResponseWriter[C]) Conn() Conn { + return w.w.Conn() +} + +// Message direct access to the response. +func (w *muxResponseWriter[C]) Message() *pool.Message { + return w.w.Message() +} + +// SetMessage replaces the response message. Ensure that Token, MessageID(udp), and Type(udp) messages are paired correctly. +func (w *muxResponseWriter[C]) SetMessage(msg *pool.Message) { + w.w.SetMessage(msg) +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/mux/regexp.go b/vendor/github.com/plgd-dev/go-coap/v3/mux/regexp.go new file mode 100644 index 0000000000..f405335330 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/mux/regexp.go @@ -0,0 +1,171 @@ +// The code in this file is taken and adapated from gorilla/mux package. + +// Copyright (c) 2012-2018 The Gorilla Authors. All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package mux + +import ( + "bytes" + "fmt" + "regexp" + "strconv" + "strings" +) + +func newRouteRegexp(path string) (*routeRegexp, error) { + // Check if it is well-formed. + idxs, errBraces := braceIndices(path) + if errBraces != nil { + return nil, errBraces + } + // Backup the original. + template := path + // Now let's parse it. + defaultPattern := "[^/]+" + varsN := make([]string, len(idxs)/2) + varsR := make([]*regexp.Regexp, len(idxs)/2) + pattern := bytes.NewBufferString("") + pattern.WriteByte('^') + reverse := bytes.NewBufferString("") + var end int + var err error + for i := 0; i < len(idxs); i += 2 { + // Set all values we are interested in. + raw := path[end:idxs[i]] + end = idxs[i+1] + parts := strings.SplitN(path[idxs[i]+1:end-1], ":", 2) + name := parts[0] + patt := defaultPattern + if len(parts) == 2 { + patt = parts[1] + } + // Name or pattern can't be empty. + if name == "" || patt == "" { + return nil, fmt.Errorf("mux: missing name or pattern in %q", + path[idxs[i]:end]) + } + // Build the regexp pattern. + fmt.Fprintf(pattern, "%s(?P<%s>%s)", regexp.QuoteMeta(raw), varGroupName(i/2), patt) + + // Build the reverse template. + fmt.Fprintf(reverse, "%s%%s", raw) + + // Append variable name and compiled pattern. + varsN[i/2] = name + varsR[i/2], err = regexp.Compile(fmt.Sprintf("^%s$", patt)) + if err != nil { + return nil, err + } + } + // Add the remaining. + raw := path[end:] + pattern.WriteString(regexp.QuoteMeta(raw)) + + pattern.WriteByte('$') + + // Compile full regexp. + reg, errCompile := regexp.Compile(pattern.String()) + if errCompile != nil { + return nil, errCompile + } + + // Check for capturing groups which used to work in older versions + if reg.NumSubexp() != len(idxs)/2 { + panic(fmt.Sprintf("route %s contains capture groups in its regexp. ", template) + + "Only non-capturing groups are accepted: e.g. (?:pattern) instead of (pattern)") + } + + // Done! 
+	return &routeRegexp{
+		template: template,
+		regexp:   reg,
+		reverse:  reverse.String(),
+		varsN:    varsN,
+		varsR:    varsR,
+	}, nil
+}
+
+// routeRegexp stores a regexp to match a host or path and information to
+// collect and validate route variables.
+type routeRegexp struct {
+	// The unmodified template.
+	template string
+
+	// Expanded regexp.
+	regexp *regexp.Regexp
+	// Reverse template.
+	reverse string
+	// Variable names.
+	varsN []string
+	// Variable regexps (validators).
+	varsR []*regexp.Regexp
+}
+
+// varGroupName builds a capturing group name for the indexed variable.
+func varGroupName(idx int) string {
+	return "v" + strconv.Itoa(idx)
+}
+
+// braceIndices returns the first level curly brace indices from a string.
+// It returns an error in case of unbalanced braces.
+func braceIndices(s string) ([]int, error) {
+	var level, idx int
+	var idxs []int
+	for i := 0; i < len(s); i++ {
+		switch s[i] {
+		case '{':
+			if level++; level == 1 {
+				idx = i
+			}
+		case '}':
+			if level--; level == 0 {
+				idxs = append(idxs, idx, i+1)
+			} else if level < 0 {
+				return nil, fmt.Errorf("mux: unbalanced braces in %q", s)
+			}
+		}
+	}
+	if level != 0 {
+		return nil, fmt.Errorf("mux: unbalanced braces in %q", s)
+	}
+	return idxs, nil
+}
+
+func (route *routeRegexp) extractRouteParams(path string, routeParams *RouteParams) {
+	matches := route.regexp.FindStringSubmatchIndex(path)
+	if len(matches) > 0 {
+		extractVars(path, matches, route.varsN, routeParams.Vars)
+	}
+}
+
+func extractVars(input string, matches []int, names []string, output map[string]string) {
+	for i, name := range names {
+		output[name] = input[matches[2*i+2]:matches[2*i+3]]
+	}
+}
diff --git a/vendor/github.com/plgd-dev/go-coap/v3/mux/router.go b/vendor/github.com/plgd-dev/go-coap/v3/mux/router.go
new file mode 100644
index 0000000000..5eea89461a
--- /dev/null
+++ b/vendor/github.com/plgd-dev/go-coap/v3/mux/router.go
@@ -0,0 +1,222 @@
+package mux
+
+import (
+	"errors"
+	"fmt"
+	"io"
+	"sync"
+
+	"github.com/plgd-dev/go-coap/v3/message"
+	"github.com/plgd-dev/go-coap/v3/message/codes"
+	"github.com/plgd-dev/go-coap/v3/message/pool"
+)
+
+type ResponseWriter = interface {
+	SetResponse(code codes.Code, contentFormat message.MediaType, d io.ReadSeeker, opts ...message.Option) error
+	Conn() Conn
+	SetMessage(m *pool.Message)
+	Message() *pool.Message
+}
+
+type Handler interface {
+	ServeCOAP(w ResponseWriter, r *Message)
+}
+
+// The HandlerFunc type is an adapter to allow the use of
+// ordinary functions as COAP handlers. If f is a function
+// with the appropriate signature, HandlerFunc(f) is a
+// Handler object that calls f.
+type HandlerFunc func(w ResponseWriter, r *Message)
+
+type ErrorFunc = func(error)
+
+// ServeCOAP calls f(w, r).
+func (f HandlerFunc) ServeCOAP(w ResponseWriter, r *Message) {
+	f(w, r)
+}
+
+// Router is a COAP request multiplexer. It matches the
+// path name of each incoming request against a list of
+// registered patterns and calls the handler for the pattern
+// with the same name.
+// Router is also safe for concurrent access from multiple goroutines.
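
A minimal usage sketch of this Router, its route variables, and the middleware chain (the path pattern, logging middleware, and handler below are illustrative only, not part of this change):

	package main

	import (
		"bytes"
		"fmt"

		"github.com/plgd-dev/go-coap/v3/message"
		"github.com/plgd-dev/go-coap/v3/message/codes"
		"github.com/plgd-dev/go-coap/v3/mux"
	)

	func main() {
		r := mux.NewRouter()

		// Middlewares wrap every matched handler; the first one registered runs outermost.
		r.Use(func(next mux.Handler) mux.Handler {
			return mux.HandlerFunc(func(w mux.ResponseWriter, m *mux.Message) {
				fmt.Printf("code=%v template=%v\n", m.Code(), m.RouteParams.PathTemplate)
				next.ServeCOAP(w, m)
			})
		})

		// "{chanID}" becomes a route variable, captured by the regexp built in regexp.go.
		r.HandleFunc("/channels/{chanID}", func(w mux.ResponseWriter, m *mux.Message) {
			id := m.RouteParams.Vars["chanID"] // filled in by Router.Match
			if err := w.SetResponse(codes.Content, message.TextPlain, bytes.NewReader([]byte(id))); err != nil {
				fmt.Println(err)
			}
		})
		_ = r // pass r as the handler when constructing the CoAP server
	}
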
+type Router struct {
+	middlewares []MiddlewareFunc
+	errors      ErrorFunc
+
+	m              *sync.RWMutex
+	defaultHandler Handler          // guarded by m
+	z              map[string]Route // guarded by m
+}
+
+type Route struct {
+	h            Handler
+	pattern      string
+	regexMatcher *routeRegexp
+}
+
+func (route *Route) GetRouteRegexp() (string, error) {
+	if route.regexMatcher.regexp == nil {
+		return "", errors.New("mux: route does not have a regexp")
+	}
+	return route.regexMatcher.regexp.String(), nil
+}
+
+// NewRouter allocates and returns a new Router.
+func NewRouter() *Router {
+	router := &Router{
+		middlewares: make([]MiddlewareFunc, 0, 2),
+		errors: func(err error) {
+			fmt.Println(err)
+		},
+
+		m: new(sync.RWMutex),
+		z: make(map[string]Route),
+	}
+	router.defaultHandler = HandlerFunc(func(w ResponseWriter, m *Message) {
+		if err := w.SetResponse(codes.NotFound, message.TextPlain, nil); err != nil {
+			router.errors(fmt.Errorf("router handler: cannot set response: %w", err))
+		}
+	})
+	return router
+}
+
+// Does path match pattern?
+func pathMatch(pattern Route, path string) bool {
+	return pattern.regexMatcher.regexp.MatchString(path)
+}
+
+// Find a handler on a handler map given a path string.
+// The most-specific (longest) pattern wins.
+func (r *Router) Match(path string, routeParams *RouteParams) (matchedRoute *Route, matchedPattern string) {
+	r.m.RLock()
+	n := 0
+	for pattern, route := range r.z {
+		if !pathMatch(route, path) {
+			continue
+		}
+		if matchedRoute == nil || len(pattern) > n {
+			n = len(pattern)
+			r := route
+			matchedRoute = &r
+			matchedPattern = pattern
+		}
+	}
+	r.m.RUnlock()
+
+	if matchedRoute == nil {
+		return
+	}
+
+	routeParams.Path = path
+	if routeParams.Vars == nil {
+		routeParams.Vars = make(map[string]string)
+	}
+	routeParams.PathTemplate = matchedPattern
+	matchedRoute.regexMatcher.extractRouteParams(path, routeParams)
+
+	return
+}
+
+// Handle adds a handler to the Router for pattern.
+func (r *Router) Handle(pattern string, handler Handler) error {
+	switch pattern {
+	case "", "/":
+		pattern = "/"
+	}
+
+	if handler == nil {
+		return errors.New("nil handler")
+	}
+
+	routeRegex, err := newRouteRegexp(pattern)
+	if err != nil {
+		return err
+	}
+
+	r.m.Lock()
+	r.z[pattern] = Route{h: handler, pattern: pattern, regexMatcher: routeRegex}
+	r.m.Unlock()
+	return nil
+}
+
+// DefaultHandle sets the default handler for the Router.
+func (r *Router) DefaultHandle(handler Handler) {
+	r.m.Lock()
+	defer r.m.Unlock()
+	r.defaultHandler = handler
+}
+
+// HandleFunc adds a handler function to the Router for pattern.
+func (r *Router) HandleFunc(pattern string, handler func(w ResponseWriter, r *Message)) {
+	if err := r.Handle(pattern, HandlerFunc(handler)); err != nil {
+		r.errors(fmt.Errorf("cannot handle pattern(%v): %w", pattern, err))
+	}
+}
+
+// DefaultHandleFunc sets the default handler function for the Router.
+func (r *Router) DefaultHandleFunc(handler func(w ResponseWriter, r *Message)) {
+	r.DefaultHandle(HandlerFunc(handler))
+}
+
+// HandleRemove deregisters the handler registered for pattern from the Router.
+func (r *Router) HandleRemove(pattern string) error {
+	switch pattern {
+	case "", "/":
+		pattern = "/"
+	}
+	r.m.Lock()
+	defer r.m.Unlock()
+	if _, ok := r.z[pattern]; ok {
+		delete(r.z, pattern)
+		return nil
+	}
+	return errors.New("pattern is not registered")
+}
+
+// GetRoute obtains the route registered for pattern, if any.
+func (r *Router) GetRoute(pattern string) *Route {
+	r.m.RLock()
+	defer r.m.RUnlock()
+	if route, ok := r.z[pattern]; ok {
+		return &route
+	}
+	return nil
+}
+
+func (r *Router) GetRoutes() map[string]Route {
+	r.m.RLock()
+	defer r.m.RUnlock()
+	return r.z
+}
+
+// ServeCOAP dispatches the request to the handler whose
+// pattern most closely matches the request message path.
+// If no handler is found, a standard NotFound response is returned.
+func (r *Router) ServeCOAP(w ResponseWriter, req *Message) {
+	path, err := req.Options().Path()
+	r.m.RLock()
+	defaultHandler := r.defaultHandler
+	r.m.RUnlock()
+	if err != nil {
+		defaultHandler.ServeCOAP(w, req)
+		return
+	}
+	var h Handler
+	matchedMuxEntry, _ := r.Match(path, req.RouteParams)
+	if matchedMuxEntry == nil {
+		h = defaultHandler
+	} else {
+		h = matchedMuxEntry.h
+	}
+	if h == nil {
+		return
+	}
+
+	for i := len(r.middlewares) - 1; i >= 0; i-- {
+		h = r.middlewares[i].Middleware(h)
+	}
+	h.ServeCOAP(w, req)
+}
diff --git a/vendor/github.com/plgd-dev/go-coap/v3/net/blockwise/blockwise.go b/vendor/github.com/plgd-dev/go-coap/v3/net/blockwise/blockwise.go
new file mode 100644
index 0000000000..1828d75eae
--- /dev/null
+++ b/vendor/github.com/plgd-dev/go-coap/v3/net/blockwise/blockwise.go
@@ -0,0 +1,844 @@
+package blockwise
+
+import (
+	"bytes"
+	"context"
+	"errors"
+	"fmt"
+	"io"
+	"time"
+
+	"github.com/dsnet/golib/memfile"
+	"github.com/plgd-dev/go-coap/v3/message"
+	"github.com/plgd-dev/go-coap/v3/message/codes"
+	"github.com/plgd-dev/go-coap/v3/message/pool"
+	"github.com/plgd-dev/go-coap/v3/net/responsewriter"
+	"github.com/plgd-dev/go-coap/v3/pkg/cache"
+	"golang.org/x/sync/semaphore"
+)
+
+// Block Option value is represented: https://tools.ietf.org/html/rfc7959#section-2.2
+//  0
+//  0 1 2 3 4 5 6 7
+// +-+-+-+-+-+-+-+-+
+// |  NUM  |M| SZX |
+// +-+-+-+-+-+-+-+-+
+//  0                   1
+//  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// |          NUM          |M| SZX |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+//  0                   1                   2
+//  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+// |                  NUM                  |M| SZX |
+// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+
+const (
+	// max block size is 3 bytes: https://tools.ietf.org/html/rfc7959#section-2.1
+	maxBlockValue = 0xffffff
+	// maxBlockNumber is 20 bits (NUM)
+	maxBlockNumber = 0xffff7
+	// moreBlocksFollowingMask is represented by one bit (M)
+	moreBlocksFollowingMask = 0x8
+	// szxMask: the last 3 bits represent SZX
+	szxMask = 0x7
+)
+
+// SZX enum representation for the size of the block: https://tools.ietf.org/html/rfc7959#section-2.2
+type SZX uint8
+
+const (
+	// SZX16 block of size 16 bytes
+	SZX16 SZX = 0
+	// SZX32 block of size 32 bytes
+	SZX32 SZX = 1
+	// SZX64 block of size 64 bytes
+	SZX64 SZX = 2
+	// SZX128 block of size 128 bytes
+	SZX128 SZX = 3
+	// SZX256 block of size 256 bytes
+	SZX256 SZX = 4
+	// SZX512 block of size 512 bytes
+	SZX512 SZX = 5
+	// SZX1024 block of size 1024 bytes
+	SZX1024 SZX = 6
+	// SZXBERT block of size n*1024 bytes
+	SZXBERT SZX = 7
+)
+
+var szxToSize = map[SZX]int64{
+	
SZX16:   16,
+	SZX32:   32,
+	SZX64:   64,
+	SZX128:  128,
+	SZX256:  256,
+	SZX512:  512,
+	SZX1024: 1024,
+	SZXBERT: 1024,
+}
+
+// Size number of bytes.
+func (s SZX) Size() int64 {
+	val, ok := szxToSize[s]
+	if ok {
+		return val
+	}
+	return -1
+}
+
+// EncodeBlockOption encodes block values to a coap option.
+func EncodeBlockOption(szx SZX, blockNumber int64, moreBlocksFollowing bool) (uint32, error) {
+	if szx > SZXBERT {
+		return 0, ErrInvalidSZX
+	}
+	if blockNumber < 0 {
+		return 0, ErrBlockNumberExceedLimit
+	}
+	if blockNumber > maxBlockNumber {
+		return 0, ErrBlockNumberExceedLimit
+	}
+	blockVal := uint32(blockNumber << 4)
+	m := uint32(0)
+	if moreBlocksFollowing {
+		m = 1
+	}
+	blockVal += m << 3
+	blockVal += uint32(szx)
+	return blockVal, nil
+}
+
+// DecodeBlockOption decodes a coap block option to block values.
+func DecodeBlockOption(blockVal uint32) (szx SZX, blockNumber int64, moreBlocksFollowing bool, err error) {
+	if blockVal > maxBlockValue {
+		err = ErrBlockInvalidSize
+		return
+	}
+
+	szx = SZX(blockVal & szxMask) // masking for the SZX
+	if (blockVal & moreBlocksFollowingMask) != 0 { // masking for the "M"
+		moreBlocksFollowing = true
+	}
+	blockNumber = int64(blockVal) >> 4 // shift out the SZX and M bits, leaving the block number
+	if blockNumber > maxBlockNumber {
+		err = ErrBlockNumberExceedLimit
+	}
+	return
+}
+
+type Client interface {
+	// create message from pool
+	AcquireMessage(ctx context.Context) *pool.Message
+	// return back the message to the pool for next use
+	ReleaseMessage(m *pool.Message)
+}
+
+type BlockWise[C Client] struct {
+	cc                        C
+	receivingMessagesCache    *cache.Cache[uint64, *messageGuard]
+	sendingMessagesCache      *cache.Cache[uint64, *pool.Message]
+	errors                    func(error)
+	getSentRequestFromOutside func(token message.Token) (*pool.Message, bool)
+	expiration                time.Duration
+}
+
+type messageGuard struct {
+	*pool.Message
+	*semaphore.Weighted
+}
+
+func newRequestGuard(request *pool.Message) *messageGuard {
+	return &messageGuard{
+		Message:  request,
+		Weighted: semaphore.NewWeighted(1),
+	}
+}
+
+// New provides blockwise.
+// getSentRequestFromOutside must return a copy of the request, which will be released after use.
+func New[C Client](
+	cc C,
+	expiration time.Duration,
+	errors func(error),
+	getSentRequestFromOutside func(token message.Token) (*pool.Message, bool),
+) *BlockWise[C] {
+	if getSentRequestFromOutside == nil {
+		getSentRequestFromOutside = func(token message.Token) (*pool.Message, bool) { return nil, false }
+	}
+	return &BlockWise[C]{
+		cc:                        cc,
+		receivingMessagesCache:    cache.NewCache[uint64, *messageGuard](),
+		sendingMessagesCache:      cache.NewCache[uint64, *pool.Message](),
+		errors:                    errors,
+		getSentRequestFromOutside: getSentRequestFromOutside,
+		expiration:                expiration,
+	}
+}
+
+func bufferSize(szx SZX, maxMessageSize uint32) int64 {
+	if szx < SZXBERT {
+		return szx.Size()
+	}
+	return (int64(maxMessageSize) / szx.Size()) * szx.Size()
+}
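
As a quick sanity check of the NUM/M/SZX packing implemented by EncodeBlockOption and DecodeBlockOption above, a hypothetical round-trip (not part of the vendored code):

	package main

	import (
		"fmt"

		"github.com/plgd-dev/go-coap/v3/net/blockwise"
	)

	func main() {
		// NUM=2, M=1, SZX=6 (1024-byte blocks): 2<<4 | 1<<3 | 6 = 0x2E.
		v, err := blockwise.EncodeBlockOption(blockwise.SZX1024, 2, true)
		if err != nil {
			panic(err)
		}
		fmt.Printf("0x%X\n", v) // prints 0x2E

		szx, num, more, err := blockwise.DecodeBlockOption(v)
		if err != nil {
			panic(err)
		}
		fmt.Println(szx.Size(), num, more) // prints 1024 2 true
	}
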
+// CheckExpirations iterates over the caches and removes expired items.
+func (b *BlockWise[C]) CheckExpirations(now time.Time) {
+	b.receivingMessagesCache.CheckExpirations(now)
+	b.sendingMessagesCache.CheckExpirations(now)
+}
+
+func (b *BlockWise[C]) cloneMessage(r *pool.Message) *pool.Message {
+	req := b.cc.AcquireMessage(r.Context())
+	req.SetCode(r.Code())
+	req.SetToken(r.Token())
+	req.ResetOptionsTo(r.Options())
+	req.SetType(r.Type())
+	return req
+}
+
+func payloadSizeError(err error) error {
+	return fmt.Errorf("cannot get size of payload: %w", err)
+}
+
+// Do sends a coap message and returns a coap response via blockwise transfer.
+func (b *BlockWise[C]) Do(r *pool.Message, maxSzx SZX, maxMessageSize uint32, do func(req *pool.Message) (*pool.Message, error)) (*pool.Message, error) {
+	if maxSzx > SZXBERT {
+		return nil, fmt.Errorf("invalid szx")
+	}
+	if len(r.Token()) == 0 {
+		return nil, fmt.Errorf("invalid token")
+	}
+
+	expire, ok := r.Context().Deadline()
+	if !ok {
+		expire = time.Now().Add(b.expiration)
+	}
+	_, loaded := b.sendingMessagesCache.LoadOrStore(r.Token().Hash(), cache.NewElement(r, expire, nil))
+	if loaded {
+		return nil, fmt.Errorf("invalid token")
+	}
+	defer b.sendingMessagesCache.Delete(r.Token().Hash())
+	if r.Body() == nil {
+		return do(r)
+	}
+	payloadSize, err := r.BodySize()
+	if err != nil {
+		return nil, payloadSizeError(err)
+	}
+	if payloadSize <= maxSzx.Size() {
+		return do(r)
+	}
+
+	switch r.Code() {
+	case codes.POST, codes.PUT:
+		break
+	default:
+		return nil, fmt.Errorf("unsupported command(%v)", r.Code())
+	}
+	req := b.cloneMessage(r)
+	defer b.cc.ReleaseMessage(req)
+	req.SetOptionUint32(message.Size1, uint32(payloadSize))
+	block, err := EncodeBlockOption(maxSzx, 0, true)
+	if err != nil {
+		return nil, fmt.Errorf("cannot encode block option(%v, %v, %v) to bw request: %w", maxSzx, 0, true, err)
+	}
+	req.SetOptionUint32(message.Block1, block)
+	newBufLen := bufferSize(maxSzx, maxMessageSize)
+	buf := make([]byte, newBufLen)
+	newOff, err := r.Body().Seek(0, io.SeekStart)
+	if err != nil {
+		return nil, fmt.Errorf("cannot seek in payload: %w", err)
+	}
+	readed, err := io.ReadFull(r.Body(), buf)
+	if errors.Is(err, io.ErrUnexpectedEOF) {
+		if newOff+int64(readed) == payloadSize {
+			err = nil
+		}
+	}
+	if err != nil {
+		return nil, fmt.Errorf("cannot read payload: %w", err)
+	}
+	buf = buf[:readed]
+	req.SetBody(bytes.NewReader(buf))
+	return do(req)
+}
+
+func newWriteRequestResponse[C Client](cc C, request *pool.Message) *responsewriter.ResponseWriter[C] {
+	req := cc.AcquireMessage(request.Context())
+	req.SetCode(request.Code())
+	req.SetToken(request.Token())
+	req.ResetOptionsTo(request.Options())
+	req.SetBody(request.Body())
+	return responsewriter.New(req, cc, request.Options()...)
+}
+
+// WriteMessage sends a coap message via blockwise transfer.
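
The Do path above only splits POST/PUT bodies larger than one block; everything else is passed through to the do callback unchanged. A sketch of the call shape (the connection type CC and the send callback are assumptions, not part of this change):

	package main

	import (
		"github.com/plgd-dev/go-coap/v3/message/pool"
		"github.com/plgd-dev/go-coap/v3/net/blockwise"
	)

	// postLargeBody drives bw.Do for a prepared POST/PUT request; send transmits
	// the first block and returns the peer's reply (e.g. Continue for Block1).
	func postLargeBody[CC blockwise.Client](bw *blockwise.BlockWise[CC], req *pool.Message, send func(*pool.Message) (*pool.Message, error)) (*pool.Message, error) {
		// Each transmitted block carries at most 1024 B; 64 KiB caps BERT buffers.
		return bw.Do(req, blockwise.SZX1024, 64*1024, send)
	}
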
+func (b *BlockWise[C]) WriteMessage(request *pool.Message, maxSZX SZX, maxMessageSize uint32, writeMessage func(r *pool.Message) error) error { + startSendingMessageBlock, err := EncodeBlockOption(maxSZX, 0, true) + if err != nil { + return fmt.Errorf("cannot encode start sending message block option(%v,%v,%v): %w", maxSZX, 0, true, err) + } + + w := newWriteRequestResponse(b.cc, request) + err = b.startSendingMessage(w, maxSZX, maxMessageSize, startSendingMessageBlock) + if err != nil { + return fmt.Errorf("cannot start writing request: %w", err) + } + return writeMessage(w.Message()) +} + +func fitSZX(r *pool.Message, blockType message.OptionID, maxSZX SZX) SZX { + block, err := r.GetOptionUint32(blockType) + if err != nil { + return maxSZX + } + + szx, _, _, err := DecodeBlockOption(block) + if err != nil { + return maxSZX + } + + if maxSZX > szx { + return szx + } + return maxSZX +} + +func (b *BlockWise[C]) sendEntityIncomplete(w *responsewriter.ResponseWriter[C], token message.Token) { + sendMessage := b.cc.AcquireMessage(w.Message().Context()) + sendMessage.SetCode(codes.RequestEntityIncomplete) + sendMessage.SetToken(token) + sendMessage.SetType(message.NonConfirmable) + w.SetMessage(sendMessage) +} + +func wantsToBeReceived(r *pool.Message) bool { + hasBlock1 := r.HasOption(message.Block1) + hasBlock2 := r.HasOption(message.Block2) + if hasBlock1 && (r.Code() == codes.POST || r.Code() == codes.PUT) { + // r contains payload which we received + return true + } + if hasBlock2 && (r.Code() >= codes.GET && r.Code() <= codes.DELETE) { + // r is command to get next block + return false + } + if r.Code() == codes.Continue { + return false + } + return true +} + +func (b *BlockWise[C]) getSendingMessageCode(token uint64) (codes.Code, bool) { + v := b.sendingMessagesCache.Load(token) + if v == nil { + return codes.Empty, false + } + return v.Data().Code(), true +} + +// Handle middleware which constructs COAP request from blockwise transfer and send COAP response via blockwise. +func (b *BlockWise[C]) Handle(w *responsewriter.ResponseWriter[C], r *pool.Message, maxSZX SZX, maxMessageSize uint32, next func(w *responsewriter.ResponseWriter[C], r *pool.Message)) { + if maxSZX > SZXBERT { + panic("invalid maxSZX") + } + token := r.Token() + + if len(token) == 0 { + err := b.handleReceivedMessage(w, r, maxSZX, maxMessageSize, next) + if err != nil { + b.sendEntityIncomplete(w, token) + b.errors(fmt.Errorf("handleReceivedMessage(%v): %w", r, err)) + } + return + } + tokenStr := token.Hash() + + sendingMessageCode, sendingMessageExist := b.getSendingMessageCode(tokenStr) + if !sendingMessageExist || wantsToBeReceived(r) { + err := b.handleReceivedMessage(w, r, maxSZX, maxMessageSize, next) + if err != nil { + b.sendEntityIncomplete(w, token) + b.errors(fmt.Errorf("handleReceivedMessage(%v): %w", r, err)) + } + return + } + more, err := b.continueSendingMessage(w, r, maxSZX, maxMessageSize, sendingMessageCode) + if err != nil { + b.sendingMessagesCache.Delete(tokenStr) + b.errors(fmt.Errorf("continueSendingMessage(%v): %w", r, err)) + return + } + // For codes GET,POST,PUT,DELETE, we want them to wait for pairing response and then delete them when the full response comes in or when timeout occurs. 
+ if !more && sendingMessageCode > codes.DELETE { + b.sendingMessagesCache.Delete(tokenStr) + } +} + +func (b *BlockWise[C]) handleReceivedMessage(w *responsewriter.ResponseWriter[C], r *pool.Message, maxSZX SZX, maxMessageSize uint32, next func(w *responsewriter.ResponseWriter[C], r *pool.Message)) error { + startSendingMessageBlock, err := EncodeBlockOption(maxSZX, 0, true) + if err != nil { + return fmt.Errorf("cannot encode start sending message block option(%v,%v,%v): %w", maxSZX, 0, true, err) + } + switch r.Code() { + case codes.Empty, codes.CSM, codes.Ping, codes.Pong, codes.Release, codes.Abort: + next(w, r) + return nil + case codes.GET, codes.DELETE: + maxSZX = fitSZX(r, message.Block2, maxSZX) + block, errG := r.GetOptionUint32(message.Block2) + if errG == nil { + r.Remove(message.Block2) + } + next(w, r) + if w.Message().Code() == codes.Content && errG == nil { + startSendingMessageBlock = block + } + case codes.POST, codes.PUT: + maxSZX = fitSZX(r, message.Block1, maxSZX) + errP := b.processReceivedMessage(w, r, maxSZX, next, message.Block1, message.Size1) + if errP != nil { + return errP + } + default: + maxSZX = fitSZX(r, message.Block2, maxSZX) + errP := b.processReceivedMessage(w, r, maxSZX, next, message.Block2, message.Size2) + if errP != nil { + return errP + } + } + return b.startSendingMessage(w, maxSZX, maxMessageSize, startSendingMessageBlock) +} + +func (b *BlockWise[C]) createSendingMessage(sendingMessage *pool.Message, maxSZX SZX, maxMessageSize uint32, block uint32) (sendMessage *pool.Message, more bool, err error) { + blockType := message.Block2 + sizeType := message.Size2 + token := sendingMessage.Token() + switch sendingMessage.Code() { + case codes.POST, codes.PUT: + blockType = message.Block1 + sizeType = message.Size1 + } + + szx, num, _, err := DecodeBlockOption(block) + if err != nil { + return nil, false, fmt.Errorf("cannot decode %v option: %w", blockType, err) + } + + sendMessage = b.cc.AcquireMessage(sendingMessage.Context()) + sendMessage.SetCode(sendingMessage.Code()) + sendMessage.ResetOptionsTo(sendingMessage.Options()) + sendMessage.SetToken(token) + sendMessage.SetType(sendingMessage.Type()) + payloadSize, err := sendingMessage.BodySize() + if err != nil { + b.cc.ReleaseMessage(sendMessage) + return nil, false, payloadSizeError(err) + } + if szx > maxSZX { + szx = maxSZX + } + newBufLen := bufferSize(szx, maxMessageSize) + off := num * szx.Size() + if blockType == message.Block1 { + // For block1, we need to skip the already sent bytes. 
+ off += newBufLen + } + offSeek, err := sendingMessage.Body().Seek(off, io.SeekStart) + if err != nil { + b.cc.ReleaseMessage(sendMessage) + return nil, false, fmt.Errorf("cannot seek in response: %w", err) + } + if off != offSeek { + b.cc.ReleaseMessage(sendMessage) + return nil, false, fmt.Errorf("cannot seek to requested offset(%v != %v)", off, offSeek) + } + buf := make([]byte, 1024) + if int64(len(buf)) < newBufLen { + buf = make([]byte, newBufLen) + } + buf = buf[:newBufLen] + + readed, err := io.ReadFull(sendingMessage.Body(), buf) + if errors.Is(err, io.ErrUnexpectedEOF) || errors.Is(err, io.EOF) { + if offSeek+int64(readed) == payloadSize { + err = nil + } + } + if err != nil { + b.cc.ReleaseMessage(sendMessage) + return nil, false, fmt.Errorf("cannot read response: %w", err) + } + + buf = buf[:readed] + sendMessage.SetBody(bytes.NewReader(buf)) + more = true + if offSeek+int64(readed) == payloadSize { + more = false + } + sendMessage.SetOptionUint32(sizeType, uint32(payloadSize)) + num = (offSeek) / szx.Size() + block, err = EncodeBlockOption(szx, num, more) + if err != nil { + b.cc.ReleaseMessage(sendMessage) + return nil, false, fmt.Errorf("cannot encode block option(%v,%v,%v): %w", szx, num, more, err) + } + sendMessage.SetOptionUint32(blockType, block) + return sendMessage, more, nil +} + +func (b *BlockWise[C]) continueSendingMessage(w *responsewriter.ResponseWriter[C], r *pool.Message, maxSZX SZX, maxMessageSize uint32, sendingMessageCode codes.Code /* msg *pool.Message*/) (bool, error) { + blockType := message.Block2 + switch sendingMessageCode { + case codes.POST, codes.PUT: + blockType = message.Block1 + } + + block, err := r.GetOptionUint32(blockType) + if err != nil { + return false, fmt.Errorf("cannot get %v option: %w", blockType, err) + } + var sendMessage *pool.Message + var more bool + b.sendingMessagesCache.LoadWithFunc(r.Token().Hash(), func(value *cache.Element[*pool.Message]) *cache.Element[*pool.Message] { + sendMessage, more, err = b.createSendingMessage(value.Data(), maxSZX, maxMessageSize, block) + if err != nil { + err = fmt.Errorf("cannot create sending message: %w", err) + } + return nil + }) + if err == nil && sendMessage == nil { + err = fmt.Errorf("cannot find sending message for token(%v)", r.Token()) + } + if err != nil { + return false, fmt.Errorf("handleSendingMessage: %w", err) + } + w.SetMessage(sendMessage) + return more, err +} + +func isObserveResponse(msg *pool.Message) bool { + _, err := msg.GetOptionUint32(message.Observe) + if err != nil { + return false + } + return msg.Code() >= codes.Created +} + +func (b *BlockWise[C]) startSendingMessage(w *responsewriter.ResponseWriter[C], maxSZX SZX, maxMessageSize uint32, block uint32) error { + payloadSize, err := w.Message().BodySize() + if err != nil { + return payloadSizeError(err) + } + + if payloadSize < maxSZX.Size() { + return nil + } + sendingMessage, _, err := b.createSendingMessage(w.Message(), maxSZX, maxMessageSize, block) + if err != nil { + return fmt.Errorf("handleSendingMessage: cannot create sending message: %w", err) + } + originalSendingMessage := w.Swap(sendingMessage) + if isObserveResponse(w.Message()) { + b.cc.ReleaseMessage(originalSendingMessage) + // https://tools.ietf.org/html/rfc7959#section-2.6 - we don't need store it because client will be get values via GET. 
+ return nil + } + expire, ok := sendingMessage.Context().Deadline() + if !ok { + expire = time.Now().Add(b.expiration) + } + el, loaded := b.sendingMessagesCache.LoadOrStore(sendingMessage.Token().Hash(), cache.NewElement(originalSendingMessage, expire, nil)) + if loaded { + defer b.cc.ReleaseMessage(originalSendingMessage) + return fmt.Errorf("cannot add message (%v) to sending message cache: message(%v) with token(%v) already exist", originalSendingMessage, el.Data(), sendingMessage.Token()) + } + return nil +} + +func (b *BlockWise[C]) getSentRequest(token message.Token) *pool.Message { + data, ok := b.sendingMessagesCache.LoadWithFunc(token.Hash(), func(value *cache.Element[*pool.Message]) *cache.Element[*pool.Message] { + if value == nil { + return nil + } + v := value.Data() + msg := b.cc.AcquireMessage(v.Context()) + msg.SetCode(v.Code()) + msg.SetToken(v.Token()) + msg.ResetOptionsTo(v.Options()) + msg.SetType(v.Type()) + return cache.NewElement(msg, value.ValidUntil.Load(), nil) + }) + if ok { + return data.Data() + } + globalRequest, ok := b.getSentRequestFromOutside(token) + if ok { + return globalRequest + } + return nil +} + +func (b *BlockWise[C]) handleObserveResponse(sentRequest *pool.Message) (message.Token, time.Time, error) { + // https://tools.ietf.org/html/rfc7959#section-2.6 - performs GET with new token. + if sentRequest == nil { + return nil, time.Time{}, fmt.Errorf("observation is not registered") + } + token, err := message.GetToken() + if err != nil { + return nil, time.Time{}, fmt.Errorf("cannot get token for create GET request: %w", err) + } + validUntil := time.Now().Add(b.expiration) // context of observation can be expired. + bwSentRequest := b.cloneMessage(sentRequest) + bwSentRequest.SetToken(token) + _, loaded := b.sendingMessagesCache.LoadOrStore(token.Hash(), cache.NewElement(bwSentRequest, validUntil, nil)) + if loaded { + return nil, time.Time{}, fmt.Errorf("cannot process message: message with token already exist") + } + return token, validUntil, nil +} + +func (b *BlockWise[C]) getValidUntil(sentRequest *pool.Message) time.Time { + validUntil := time.Now().Add(b.expiration) + if sentRequest != nil { + if deadline, ok := sentRequest.Context().Deadline(); ok { + return deadline + } + } + return validUntil +} + +func getSzx(szx, maxSzx SZX) SZX { + if szx > maxSzx { + return maxSzx + } + return szx +} + +func (b *BlockWise[C]) getPayloadFromCachedReceivedMessage(r, cachedReceivedMessage *pool.Message) (*memfile.File, int64, error) { + payloadFile, ok := cachedReceivedMessage.Body().(*memfile.File) + if !ok { + return nil, 0, fmt.Errorf("invalid body type(%T) stored in receivingMessagesCache", cachedReceivedMessage.Body()) + } + rETAG, errETAG := r.GetOptionBytes(message.ETag) + cachedReceivedMessageETAG, errCachedReceivedMessageETAG := cachedReceivedMessage.GetOptionBytes(message.ETag) + switch { + case errETAG == nil && errCachedReceivedMessageETAG != nil: + if len(cachedReceivedMessageETAG) > 0 { // make sure there is an etag there + return nil, 0, fmt.Errorf("received message doesn't contains ETAG but cached received message contains it(%v)", cachedReceivedMessageETAG) + } + case errETAG != nil && errCachedReceivedMessageETAG == nil: + if len(rETAG) > 0 { // make sure there is an etag there + return nil, 0, fmt.Errorf("received message contains ETAG(%v) but cached received message doesn't", rETAG) + } + case !bytes.Equal(rETAG, cachedReceivedMessageETAG): + // ETAG was changed - drop data and set new ETAG + 
cachedReceivedMessage.SetOptionBytes(message.ETag, rETAG) + if err := payloadFile.Truncate(0); err != nil { + return nil, 0, fmt.Errorf("cannot truncate cached request: %w", err) + } + } + + payloadSize, err := cachedReceivedMessage.BodySize() + if err != nil { + return nil, 0, payloadSizeError(err) + } + return payloadFile, payloadSize, nil +} + +func copyToPayloadFromOffset(r *pool.Message, payloadFile *memfile.File, offset int64) (int64, error) { + payloadSize := int64(0) + copyn, err := payloadFile.Seek(offset, io.SeekStart) + if err != nil { + return 0, fmt.Errorf("cannot seek to off(%v) of cached request: %w", offset, err) + } + written := int64(0) + if r.Body() != nil { + _, err = r.Body().Seek(0, io.SeekStart) + if err != nil { + return 0, fmt.Errorf("cannot seek to start of request: %w", err) + } + written, err = io.Copy(payloadFile, r.Body()) + if err != nil { + return 0, fmt.Errorf("cannot copy to cached request: %w", err) + } + } + payloadSize = copyn + written + err = payloadFile.Truncate(payloadSize) + if err != nil { + return 0, fmt.Errorf("cannot truncate cached request: %w", err) + } + return payloadSize, nil +} + +func (b *BlockWise[C]) getCachedReceivedMessage(mg *messageGuard, r *pool.Message, tokenStr uint64, validUntil time.Time) (*pool.Message, func(), error) { + cannotLockError := func(err error) error { + return fmt.Errorf("processReceivedMessage: cannot lock message: %w", err) + } + if mg != nil { + errA := mg.Acquire(mg.Context(), 1) + if errA != nil { + return nil, nil, cannotLockError(errA) + } + return mg.Message, func() { mg.Release(1) }, nil + } + closeFnList := []func(){} + appendToClose := func(m *messageGuard) { + closeFnList = append(closeFnList, func() { + m.Release(1) + }) + } + closeFn := func() { + for i := range closeFnList { + closeFnList[len(closeFnList)-1-i]() + } + } + msg := b.cc.AcquireMessage(r.Context()) + msg.ResetOptionsTo(r.Options()) + msg.SetToken(r.Token()) + msg.SetSequence(r.Sequence()) + msg.SetBody(memfile.New(make([]byte, 0, 1024))) + msg.SetCode(r.Code()) + mg = newRequestGuard(msg) + errA := mg.Acquire(mg.Context(), 1) + if errA != nil { + return nil, nil, cannotLockError(errA) + } + appendToClose(mg) + element, loaded := b.receivingMessagesCache.LoadOrStore(tokenStr, cache.NewElement(mg, validUntil, func(d *messageGuard) { + if d == nil { + return + } + b.sendingMessagesCache.Delete(tokenStr) + })) + // request was already stored in cache, silently + if loaded { + mg = element.Data() + if mg == nil { + closeFn() + return nil, nil, fmt.Errorf("request was already stored in cache") + } + errA := mg.Acquire(mg.Context(), 1) + if errA != nil { + closeFn() + return nil, nil, cannotLockError(errA) + } + appendToClose(mg) + } + + return mg.Message, closeFn, nil +} + +//nolint:gocyclo,gocognit +func (b *BlockWise[C]) processReceivedMessage(w *responsewriter.ResponseWriter[C], r *pool.Message, maxSzx SZX, next func(w *responsewriter.ResponseWriter[C], r *pool.Message), blockType message.OptionID, sizeType message.OptionID) error { + // TODO: lower cyclomatic complexity + token := r.Token() + if len(token) == 0 { + next(w, r) + return nil + } + if r.Code() == codes.GET || r.Code() == codes.DELETE { + next(w, r) + return nil + } + block, err := r.GetOptionUint32(blockType) + if err != nil { + if errors.Is(err, message.ErrOptionNotFound) { + next(w, r) + return nil + } + return fmt.Errorf("cannot get Block(optionID=%d) option: %w", blockType, err) + } + szx, num, more, err := DecodeBlockOption(block) + if err != nil { + return 
fmt.Errorf("cannot decode block option: %w", err) + } + sentRequest := b.getSentRequest(token) + if sentRequest != nil { + defer b.cc.ReleaseMessage(sentRequest) + } + validUntil := b.getValidUntil(sentRequest) + if blockType == message.Block2 && sentRequest == nil { + return fmt.Errorf("cannot request body without paired request") + } + if isObserveResponse(r) { + token, validUntil, err = b.handleObserveResponse(sentRequest) + if err != nil { + return fmt.Errorf("cannot process message: %w", err) + } + } + + tokenStr := token.Hash() + var cachedReceivedMessageGuard *messageGuard + if e := b.receivingMessagesCache.Load(tokenStr); e != nil { + cachedReceivedMessageGuard = e.Data() + } + if cachedReceivedMessageGuard == nil { + szx = getSzx(szx, maxSzx) + // if there is no more then just forward req to next handler + if !more { + next(w, r) + return nil + } + } + cachedReceivedMessage, closeCachedReceivedMessage, err := b.getCachedReceivedMessage(cachedReceivedMessageGuard, r, tokenStr, validUntil) + if err != nil { + return err + } + defer closeCachedReceivedMessage() + + defer func(err *error) { + if *err != nil { + b.receivingMessagesCache.Delete(tokenStr) + } + }(&err) + payloadFile, payloadSize, err := b.getPayloadFromCachedReceivedMessage(r, cachedReceivedMessage) + if err != nil { + return fmt.Errorf("cannot get payload: %w", err) + } + off := num * szx.Size() + if off == payloadSize { + payloadSize, err = copyToPayloadFromOffset(r, payloadFile, off) + if err != nil { + return fmt.Errorf("cannot copy data to payload: %w", err) + } + if !more { + b.receivingMessagesCache.Delete(tokenStr) + cachedReceivedMessage.Remove(blockType) + cachedReceivedMessage.Remove(sizeType) + cachedReceivedMessage.SetType(r.Type()) + if !bytes.Equal(cachedReceivedMessage.Token(), token) { + b.sendingMessagesCache.Delete(tokenStr) + } + _, errS := cachedReceivedMessage.Body().Seek(0, io.SeekStart) + if errS != nil { + return fmt.Errorf("cannot seek to start of cachedReceivedMessage request: %w", errS) + } + next(w, cachedReceivedMessage) + return nil + } + } + + szx = getSzx(szx, maxSzx) + sendMessage := b.cc.AcquireMessage(r.Context()) + sendMessage.SetToken(token) + if blockType == message.Block2 { + num = payloadSize / szx.Size() + sendMessage.ResetOptionsTo(sentRequest.Options()) + sendMessage.SetCode(sentRequest.Code()) + sendMessage.Remove(message.Observe) + sendMessage.Remove(message.Block1) + sendMessage.Remove(message.Size1) + } else { + sendMessage.SetCode(codes.Continue) + } + respBlock, err := EncodeBlockOption(szx, num, more) + if err != nil { + b.cc.ReleaseMessage(sendMessage) + return fmt.Errorf("cannot encode block option(%v,%v,%v): %w", szx, num, more, err) + } + sendMessage.SetOptionUint32(blockType, respBlock) + w.SetMessage(sendMessage) + return nil +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/net/blockwise/error.go b/vendor/github.com/plgd-dev/go-coap/v3/net/blockwise/error.go new file mode 100644 index 0000000000..6bc7763de8 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/net/blockwise/error.go @@ -0,0 +1,26 @@ +package blockwise + +import "errors" + +var ( + // ErrBlockNumberExceedLimit block number exceeded the limit 1,048,576 + ErrBlockNumberExceedLimit = errors.New("block number exceeded the limit 1,048,576") + + // ErrBlockInvalidSize block has invalid size + ErrBlockInvalidSize = errors.New("block has invalid size") + + // ErrInvalidOptionBlock2 message has invalid value of Block2 + ErrInvalidOptionBlock2 = errors.New("message has invalid value of Block2") + + 
// ErrInvalidOptionBlock1 message has invalid value of Block1 + ErrInvalidOptionBlock1 = errors.New("message has invalid value of Block1") + + // ErrInvalidResponseCode response code has invalid value + ErrInvalidResponseCode = errors.New("response code has invalid value") + + // ErrInvalidPayloadSize invalid payload size + ErrInvalidPayloadSize = errors.New("invalid payload size") + + // ErrInvalidSZX invalid block-wise transfer szx + ErrInvalidSZX = errors.New("invalid block-wise transfer szx") +) diff --git a/vendor/github.com/plgd-dev/go-coap/v3/net/client/client.go b/vendor/github.com/plgd-dev/go-coap/v3/net/client/client.go new file mode 100644 index 0000000000..b2cdca9ec1 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/net/client/client.go @@ -0,0 +1,241 @@ +package client + +import ( + "context" + "fmt" + "io" + + "github.com/plgd-dev/go-coap/v3/message" + "github.com/plgd-dev/go-coap/v3/message/pool" + limitparallelrequests "github.com/plgd-dev/go-coap/v3/net/client/limitParallelRequests" + "github.com/plgd-dev/go-coap/v3/net/observation" +) + +type ( + GetTokenFunc = func() (message.Token, error) +) + +type Conn interface { + // create message from pool + AcquireMessage(ctx context.Context) *pool.Message + // return back the message to the pool for next use + ReleaseMessage(m *pool.Message) + WriteMessage(req *pool.Message) error + AsyncPing(receivedPong func()) (func(), error) + Context() context.Context +} + +type Client[C Conn] struct { + cc Conn + observationHandler *observation.Handler[C] + getToken GetTokenFunc + *limitparallelrequests.LimitParallelRequests +} + +func New[C Conn](cc C, observationHandler *observation.Handler[C], getToken GetTokenFunc, limitParallelRequests *limitparallelrequests.LimitParallelRequests) *Client[C] { + return &Client[C]{ + cc: cc, + observationHandler: observationHandler, + getToken: getToken, + LimitParallelRequests: limitParallelRequests, + } +} + +func (c *Client[C]) GetToken() (message.Token, error) { + return c.getToken() +} + +// NewGetRequest creates get request. +// +// Use ctx to set timeout. +func (c *Client[C]) NewGetRequest(ctx context.Context, path string, opts ...message.Option) (*pool.Message, error) { + req := c.cc.AcquireMessage(ctx) + token, err := c.GetToken() + if err != nil { + c.cc.ReleaseMessage(req) + return nil, err + } + err = req.SetupGet(path, token, opts...) + if err != nil { + c.cc.ReleaseMessage(req) + return nil, err + } + return req, nil +} + +// Get issues a GET to the specified path. +// +// Use ctx to set timeout. +// +// An error is returned if by failure to speak COAP (such as a network connectivity problem). +// Any status code doesn't cause an error. +func (c *Client[C]) Get(ctx context.Context, path string, opts ...message.Option) (*pool.Message, error) { + req, err := c.NewGetRequest(ctx, path, opts...) + if err != nil { + return nil, fmt.Errorf("cannot create get request: %w", err) + } + defer c.cc.ReleaseMessage(req) + return c.Do(req) +} + +type Observation = interface { + Cancel(ctx context.Context, opts ...message.Option) error + Canceled() bool +} + +// NewObserveRequest creates observe request. +// +// Use ctx to set timeout. +func (c *Client[C]) NewObserveRequest(ctx context.Context, path string, opts ...message.Option) (*pool.Message, error) { + req, err := c.NewGetRequest(ctx, path, opts...) + if err != nil { + return nil, err + } + req.SetObserve(0) + return req, nil +} + +// Observe subscribes for every change of resource on path. 
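
A sketch of the convenience API above, assuming c wraps an already-established connection of some type C satisfying client.Conn (the paths and handling are illustrative):

	package main

	import (
		"context"
		"fmt"

		"github.com/plgd-dev/go-coap/v3/message/pool"
		"github.com/plgd-dev/go-coap/v3/net/client"
	)

	func readAndWatch[C client.Conn](ctx context.Context, c *client.Client[C]) error {
		resp, err := c.Get(ctx, "/a/light")
		if err != nil {
			return err
		}
		fmt.Println("GET:", resp.Code())

		obs, err := c.Observe(ctx, "/a/light", func(n *pool.Message) {
			fmt.Println("notification:", n.Code())
		})
		if err != nil {
			return err
		}
		return obs.Cancel(ctx) // stop observing
	}
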
+func (c *Client[C]) Observe(ctx context.Context, path string, observeFunc func(req *pool.Message), opts ...message.Option) (Observation, error) { + req, err := c.NewObserveRequest(ctx, path, opts...) + if err != nil { + return nil, err + } + defer c.cc.ReleaseMessage(req) + return c.DoObserve(req, observeFunc) +} + +func (c *Client[C]) GetObservationRequest(token message.Token) (*pool.Message, bool) { + return c.observationHandler.GetObservationRequest(token) +} + +// NewPostRequest creates post request. +// +// Use ctx to set timeout. +// +// An error is returned if by failure to speak COAP (such as a network connectivity problem). +// Any status code doesn't cause an error. +// +// If payload is nil then content format is not used. +func (c *Client[C]) NewPostRequest(ctx context.Context, path string, contentFormat message.MediaType, payload io.ReadSeeker, opts ...message.Option) (*pool.Message, error) { + req := c.cc.AcquireMessage(ctx) + token, err := c.GetToken() + if err != nil { + c.cc.ReleaseMessage(req) + return nil, err + } + err = req.SetupPost(path, token, contentFormat, payload, opts...) + if err != nil { + c.cc.ReleaseMessage(req) + return nil, err + } + return req, nil +} + +// Post issues a POST to the specified path. +// +// Use ctx to set timeout. +// +// An error is returned if by failure to speak COAP (such as a network connectivity problem). +// Any status code doesn't cause an error. +// +// If payload is nil then content format is not used. +func (c *Client[C]) Post(ctx context.Context, path string, contentFormat message.MediaType, payload io.ReadSeeker, opts ...message.Option) (*pool.Message, error) { + req, err := c.NewPostRequest(ctx, path, contentFormat, payload, opts...) + if err != nil { + return nil, fmt.Errorf("cannot create post request: %w", err) + } + defer c.cc.ReleaseMessage(req) + return c.Do(req) +} + +// NewPutRequest creates put request. +// +// Use ctx to set timeout. +// +// If payload is nil then content format is not used. +func (c *Client[C]) NewPutRequest(ctx context.Context, path string, contentFormat message.MediaType, payload io.ReadSeeker, opts ...message.Option) (*pool.Message, error) { + req := c.cc.AcquireMessage(ctx) + token, err := c.GetToken() + if err != nil { + c.cc.ReleaseMessage(req) + return nil, err + } + err = req.SetupPut(path, token, contentFormat, payload, opts...) + if err != nil { + c.cc.ReleaseMessage(req) + return nil, err + } + return req, nil +} + +// Put issues a PUT to the specified path. +// +// Use ctx to set timeout. +// +// An error is returned if by failure to speak COAP (such as a network connectivity problem). +// Any status code doesn't cause an error. +// +// If payload is nil then content format is not used. +func (c *Client[C]) Put(ctx context.Context, path string, contentFormat message.MediaType, payload io.ReadSeeker, opts ...message.Option) (*pool.Message, error) { + req, err := c.NewPutRequest(ctx, path, contentFormat, payload, opts...) + if err != nil { + return nil, fmt.Errorf("cannot create put request: %w", err) + } + defer c.cc.ReleaseMessage(req) + return c.Do(req) +} + +// NewDeleteRequest creates delete request. +// +// Use ctx to set timeout. +func (c *Client[C]) NewDeleteRequest(ctx context.Context, path string, opts ...message.Option) (*pool.Message, error) { + req := c.cc.AcquireMessage(ctx) + token, err := c.GetToken() + if err != nil { + c.cc.ReleaseMessage(req) + return nil, err + } + err = req.SetupDelete(path, token, opts...) 
+ if err != nil { + c.cc.ReleaseMessage(req) + return nil, err + } + return req, nil +} + +// Delete deletes the resource identified by the request path. +// +// Use ctx to set timeout. +func (c *Client[C]) Delete(ctx context.Context, path string, opts ...message.Option) (*pool.Message, error) { + req, err := c.NewDeleteRequest(ctx, path, opts...) + if err != nil { + return nil, fmt.Errorf("cannot create delete request: %w", err) + } + defer c.cc.ReleaseMessage(req) + return c.Do(req) +} + +// Ping issues a PING to the client and waits for PONG response. +// +// Use ctx to set timeout. +func (c *Client[C]) Ping(ctx context.Context) error { + resp := make(chan bool, 1) + receivedPong := func() { + select { + case resp <- true: + default: + } + } + cancel, err := c.cc.AsyncPing(receivedPong) + if err != nil { + return err + } + defer cancel() + select { + case <-resp: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/net/client/limitParallelRequests/limitParallelRequests.go b/vendor/github.com/plgd-dev/go-coap/v3/net/client/limitParallelRequests/limitParallelRequests.go new file mode 100644 index 0000000000..c718fd2115 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/net/client/limitParallelRequests/limitParallelRequests.go @@ -0,0 +1,135 @@ +package limitparallelrequests + +import ( + "context" + "fmt" + "hash/crc64" + "math" + + "github.com/plgd-dev/go-coap/v3/message" + "github.com/plgd-dev/go-coap/v3/message/pool" + coapSync "github.com/plgd-dev/go-coap/v3/pkg/sync" + "golang.org/x/sync/semaphore" +) + +type ( + DoFunc = func(req *pool.Message) (*pool.Message, error) + DoObserveFunc = func(req *pool.Message, observeFunc func(req *pool.Message)) (Observation, error) +) + +type Observation = interface { + Cancel(ctx context.Context, opts ...message.Option) error + Canceled() bool +} + +type endpointQueue struct { + processedCounter int64 + orderedRequest []chan struct{} +} + +type LimitParallelRequests struct { + endpointLimit int64 + limit *semaphore.Weighted + do DoFunc + doObserve DoObserveFunc + // only one request can be processed by one endpoint + endpointQueues *coapSync.Map[uint64, *endpointQueue] +} + +// New creates new LimitParallelRequests. When limit, endpointLimit == 0, then limit is not used. 
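
Hypothetical wiring for the limiter constructed below (rawDo and rawDoObserve stand in for the connection's unrestricted send paths; they are not part of this change):

	package main

	import (
		"github.com/plgd-dev/go-coap/v3/message/pool"
		limitparallelrequests "github.com/plgd-dev/go-coap/v3/net/client/limitParallelRequests"
	)

	func rawDo(req *pool.Message) (*pool.Message, error) { return req, nil }

	func rawDoObserve(req *pool.Message, _ func(req *pool.Message)) (limitparallelrequests.Observation, error) {
		return nil, nil
	}

	func main() {
		// At most 16 requests in flight overall and 1 per endpoint, where an
		// endpoint is identified by the CRC-64 hash of the request's URI-Path.
		limiter := limitparallelrequests.New(16, 1, rawDo, rawDoObserve)
		_ = limiter // limiter.Do / limiter.DoObserve now gate the raw functions
	}
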
+func New(limit, endpointLimit int64, do DoFunc, doObserve DoObserveFunc) *LimitParallelRequests { + if limit <= 0 { + limit = math.MaxInt64 + } + if endpointLimit <= 0 { + endpointLimit = math.MaxInt64 + } + return &LimitParallelRequests{ + limit: semaphore.NewWeighted(limit), + endpointLimit: endpointLimit, + do: do, + doObserve: doObserve, + endpointQueues: coapSync.NewMap[uint64, *endpointQueue](), + } +} + +func hash(opts message.Options) uint64 { + h := crc64.New(crc64.MakeTable(crc64.ISO)) + for _, opt := range opts { + if opt.ID == message.URIPath { + _, _ = h.Write(opt.Value) // hash never returns an error + } + } + return h.Sum64() +} + +func (c *LimitParallelRequests) acquireEndpoint(ctx context.Context, endpointLimitKey uint64) error { + reqChan := make(chan struct{}) // channel is closed when request can be processed by releaseEndpoint + _, _ = c.endpointQueues.LoadOrStoreWithFunc(endpointLimitKey, func(value *endpointQueue) *endpointQueue { + if value.processedCounter < c.endpointLimit { + close(reqChan) + value.processedCounter++ + return value + } + value.orderedRequest = append(value.orderedRequest, reqChan) + return value + }, func() *endpointQueue { + close(reqChan) + return &endpointQueue{ + processedCounter: 1, + } + }) + select { + case <-ctx.Done(): + c.releaseEndpoint(endpointLimitKey) + return ctx.Err() + case <-reqChan: + return nil + } +} + +func (c *LimitParallelRequests) releaseEndpoint(endpointLimitKey uint64) { + _, _ = c.endpointQueues.ReplaceWithFunc(endpointLimitKey, func(oldValue *endpointQueue, oldLoaded bool) (newValue *endpointQueue, doDelete bool) { + if oldLoaded { + if len(oldValue.orderedRequest) > 0 { + reqChan := oldValue.orderedRequest[0] + oldValue.orderedRequest = oldValue.orderedRequest[1:] + close(reqChan) + } else { + oldValue.processedCounter-- + if oldValue.processedCounter == 0 { + return nil, true + } + } + return oldValue, false + } + return nil, true + }) +} + +func (c *LimitParallelRequests) Do(req *pool.Message) (*pool.Message, error) { + endpointLimitKey := hash(req.Options()) + if err := c.acquireEndpoint(req.Context(), endpointLimitKey); err != nil { + return nil, fmt.Errorf("cannot process request %v for client endpoint limit: %w", req, err) + } + defer c.releaseEndpoint(endpointLimitKey) + if err := c.limit.Acquire(req.Context(), 1); err != nil { + return nil, fmt.Errorf("cannot process request %v for client limit: %w", req, err) + } + defer c.limit.Release(1) + return c.do(req) +} + +func (c *LimitParallelRequests) DoObserve(req *pool.Message, observeFunc func(req *pool.Message)) (Observation, error) { + endpointLimitKey := hash(req.Options()) + if err := c.acquireEndpoint(req.Context(), endpointLimitKey); err != nil { + return nil, fmt.Errorf("cannot process observe request %v for client endpoint limit: %w", req, err) + } + defer c.releaseEndpoint(endpointLimitKey) + err := c.limit.Acquire(req.Context(), 1) + if err != nil { + return nil, fmt.Errorf("cannot process observe request %v for client limit: %w", req, err) + } + defer c.limit.Release(1) + return c.doObserve(req, observeFunc) +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/net/client/receivedMessageReader.go b/vendor/github.com/plgd-dev/go-coap/v3/net/client/receivedMessageReader.go new file mode 100644 index 0000000000..acbf4ccab0 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/net/client/receivedMessageReader.go @@ -0,0 +1,96 @@ +package client + +import ( + "sync" + + "github.com/plgd-dev/go-coap/v3/message/pool" + "go.uber.org/atomic" +) + +type 
ReceivedMessageReaderClient interface {
+	Done() <-chan struct{}
+	ProcessReceivedMessage(req *pool.Message)
+}
+
+type ReceivedMessageReader[C ReceivedMessageReaderClient] struct {
+	queue chan *pool.Message
+	cc    C
+
+	private struct {
+		mutex           sync.Mutex
+		loopDone        chan struct{}
+		readingMessages *atomic.Bool
+	}
+}
+
+// NewReceivedMessageReader creates a new ReceivedMessageReader[C] instance.
+func NewReceivedMessageReader[C ReceivedMessageReaderClient](cc C, queueSize int) *ReceivedMessageReader[C] {
+	r := ReceivedMessageReader[C]{
+		queue: make(chan *pool.Message, queueSize),
+		cc:    cc,
+		private: struct {
+			mutex           sync.Mutex
+			loopDone        chan struct{}
+			readingMessages *atomic.Bool
+		}{
+			loopDone:        make(chan struct{}),
+			readingMessages: atomic.NewBool(true),
+		},
+	}
+
+	go r.loop(r.private.loopDone, r.private.readingMessages)
+	return &r
+}
+
+// C returns the channel to push received messages to.
+func (r *ReceivedMessageReader[C]) C() chan<- *pool.Message {
+	return r.queue
+}
+
+// The loop function continuously listens for messages. It can be replaced with a new one by calling the TryToReplaceLoop function,
+// ensuring that only one loop is reading from the queue at a time.
+// The loopDone channel is used to signal when the loop should be closed.
+// The readingMessages variable is used to indicate whether the loop is currently reading from the queue.
+// When the loop is not reading from the queue, it sets readingMessages to false, and when it starts reading again, it sets it to true.
+// If the client is closed, the loop also closes.
+func (r *ReceivedMessageReader[C]) loop(loopDone chan struct{}, readingMessages *atomic.Bool) {
+	for {
+		select {
+		// if the loop is replaced, the old loop will be closed
+		case <-loopDone:
+			return
+		// process received messages until the queue is empty
+		case req := <-r.queue:
+			// This signals that the loop is not reading messages.
+			readingMessages.Store(false)
+			r.cc.ProcessReceivedMessage(req)
+			// This signals that the loop is reading messages again. We take the mutex to ensure that TryToReplaceLoop has ended and
+			// that loopDone was closed if the loop was replaced.
+			r.private.mutex.Lock()
+			readingMessages.Store(true)
+			r.private.mutex.Unlock()
+		// if the client is closed, the loop will be closed
+		case <-r.cc.Done():
+			return
+		}
+	}
+}
+
+// TryToReplaceLoop attempts to replace the loop with a new one,
+// but only if the loop is not currently reading messages. If the loop is reading messages,
+// the function returns immediately. If the loop is not reading messages, the current loop is closed,
+// and new loopDone and readingMessages channels and variables are created.
+func (r *ReceivedMessageReader[C]) TryToReplaceLoop() {
+	r.private.mutex.Lock()
+	if r.private.readingMessages.Load() {
+		r.private.mutex.Unlock()
+		return
+	}
+	defer r.private.mutex.Unlock()
+	close(r.private.loopDone)
+	loopDone := make(chan struct{})
+	readingMessages := atomic.NewBool(true)
+	r.private.loopDone = loopDone
+	r.private.readingMessages = readingMessages
+	go r.loop(loopDone, readingMessages)
+}
diff --git a/vendor/github.com/plgd-dev/go-coap/v3/net/conn.go b/vendor/github.com/plgd-dev/go-coap/v3/net/conn.go
new file mode 100644
index 0000000000..8d6faabea9
--- /dev/null
+++ b/vendor/github.com/plgd-dev/go-coap/v3/net/conn.go
@@ -0,0 +1,128 @@
+package net
+
+import (
+	"context"
+	"fmt"
+	"net"
+	"sync"
+
+	"go.uber.org/atomic"
+)
+
+// Conn is a generic stream-oriented network connection that provides Read/Write with context.
+// +// Multiple goroutines may invoke methods on a Conn simultaneously. +type Conn struct { + connection net.Conn + closed atomic.Bool + handshakeContext func(ctx context.Context) error + lock sync.Mutex +} + +// NewConn creates connection over net.Conn. +func NewConn(c net.Conn) *Conn { + connection := Conn{ + connection: c, + } + + if v, ok := c.(interface { + HandshakeContext(ctx context.Context) error + }); ok { + connection.handshakeContext = v.HandshakeContext + } + + return &connection +} + +// LocalAddr returns the local network address. The Addr returned is shared by all invocations of LocalAddr, so do not modify it. +func (c *Conn) LocalAddr() net.Addr { + return c.connection.LocalAddr() +} + +// NetConn returns the underlying connection that is wrapped by c. The Conn returned is shared by all invocations of Connection, so do not modify it. +func (c *Conn) NetConn() net.Conn { + return c.connection +} + +// RemoteAddr returns the remote network address. The Addr returned is shared by all invocations of RemoteAddr, so do not modify it. +func (c *Conn) RemoteAddr() net.Addr { + return c.connection.RemoteAddr() +} + +// Close closes the connection. +func (c *Conn) Close() error { + if !c.closed.CompareAndSwap(false, true) { + return nil + } + return c.connection.Close() +} + +func (c *Conn) handshake(ctx context.Context) error { + if c.handshakeContext != nil { + err := c.handshakeContext(ctx) + if err == nil { + return nil + } + errC := c.Close() + if errC == nil { + return err + } + return fmt.Errorf("%v", []error{err, errC}) + } + return nil +} + +// WriteWithContext writes data with context. +func (c *Conn) WriteWithContext(ctx context.Context, data []byte) error { + if err := c.handshake(ctx); err != nil { + return err + } + written := 0 + c.lock.Lock() + defer c.lock.Unlock() + for written < len(data) { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + if c.closed.Load() { + return ErrConnectionIsClosed + } + n, err := c.connection.Write(data[written:]) + if err != nil { + return err + } + written += n + } + return nil +} + +// ReadFullWithContext reads stream with context until whole buffer is satisfied. +func (c *Conn) ReadFullWithContext(ctx context.Context, buffer []byte) error { + offset := 0 + for offset < len(buffer) { + n, err := c.ReadWithContext(ctx, buffer[offset:]) + if err != nil { + return fmt.Errorf("cannot read full from connection: %w", err) + } + offset += n + } + return nil +} + +// ReadWithContext reads stream with context. +func (c *Conn) ReadWithContext(ctx context.Context, buffer []byte) (int, error) { + select { + case <-ctx.Done(): + return -1, ctx.Err() + default: + } + if c.closed.Load() { + return -1, ErrConnectionIsClosed + } + if err := c.handshake(ctx); err != nil { + return -1, err + } + return c.connection.Read(buffer) +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/net/connUDP.go b/vendor/github.com/plgd-dev/go-coap/v3/net/connUDP.go new file mode 100644 index 0000000000..f749b22d26 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/net/connUDP.go @@ -0,0 +1,481 @@ +package net + +import ( + "context" + "fmt" + "net" + "strings" + "time" + + "go.uber.org/atomic" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +// UDPConn is a udp connection provides Read/Write with context. +// +// Multiple goroutines may invoke methods on a UDPConn simultaneously. 
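
A usage sketch for the context-aware Conn wrapper above (the endpoint and payload are illustrative):

	package main

	import (
		"context"
		"net"
		"time"

		coapNet "github.com/plgd-dev/go-coap/v3/net"
	)

	func main() {
		tcpConn, err := net.Dial("tcp", "127.0.0.1:5683")
		if err != nil {
			panic(err)
		}
		c := coapNet.NewConn(tcpConn)
		defer c.Close()

		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
		defer cancel()
		// The write is serialized against other writers and aborts between
		// write chunks once ctx expires.
		if err := c.WriteWithContext(ctx, []byte("payload")); err != nil {
			panic(err)
		}
	}
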
+type UDPConn struct { + packetConn packetConn + network string + connection *net.UDPConn + errors func(err error) + closed atomic.Bool +} + +type ControlMessage struct { + Src net.IP // source address, specifying only + IfIndex int // interface index, must be 1 <= value when specifying +} + +type packetConn interface { + SetWriteDeadline(t time.Time) error + WriteTo(b []byte, cm *ControlMessage, dst net.Addr) (n int, err error) + SetMulticastInterface(ifi *net.Interface) error + SetMulticastHopLimit(hoplim int) error + SetMulticastLoopback(on bool) error + JoinGroup(ifi *net.Interface, group net.Addr) error + LeaveGroup(ifi *net.Interface, group net.Addr) error +} + +type packetConnIPv4 struct { + packetConnIPv4 *ipv4.PacketConn +} + +func newPacketConnIPv4(p *ipv4.PacketConn) *packetConnIPv4 { + return &packetConnIPv4{p} +} + +func (p *packetConnIPv4) SetMulticastInterface(ifi *net.Interface) error { + return p.packetConnIPv4.SetMulticastInterface(ifi) +} + +func (p *packetConnIPv4) SetWriteDeadline(t time.Time) error { + return p.packetConnIPv4.SetWriteDeadline(t) +} + +func (p *packetConnIPv4) WriteTo(b []byte, cm *ControlMessage, dst net.Addr) (n int, err error) { + var c *ipv4.ControlMessage + if cm != nil { + c = &ipv4.ControlMessage{ + Src: cm.Src, + IfIndex: cm.IfIndex, + } + } + return p.packetConnIPv4.WriteTo(b, c, dst) +} + +func (p *packetConnIPv4) SetMulticastHopLimit(hoplim int) error { + return p.packetConnIPv4.SetMulticastTTL(hoplim) +} + +func (p *packetConnIPv4) SetMulticastLoopback(on bool) error { + return p.packetConnIPv4.SetMulticastLoopback(on) +} + +func (p *packetConnIPv4) JoinGroup(ifi *net.Interface, group net.Addr) error { + return p.packetConnIPv4.JoinGroup(ifi, group) +} + +func (p *packetConnIPv4) LeaveGroup(ifi *net.Interface, group net.Addr) error { + return p.packetConnIPv4.LeaveGroup(ifi, group) +} + +type packetConnIPv6 struct { + packetConnIPv6 *ipv6.PacketConn +} + +func newPacketConnIPv6(p *ipv6.PacketConn) *packetConnIPv6 { + return &packetConnIPv6{p} +} + +func (p *packetConnIPv6) SetMulticastInterface(ifi *net.Interface) error { + return p.packetConnIPv6.SetMulticastInterface(ifi) +} + +func (p *packetConnIPv6) SetWriteDeadline(t time.Time) error { + return p.packetConnIPv6.SetWriteDeadline(t) +} + +func (p *packetConnIPv6) WriteTo(b []byte, cm *ControlMessage, dst net.Addr) (n int, err error) { + var c *ipv6.ControlMessage + if cm != nil { + c = &ipv6.ControlMessage{ + Src: cm.Src, + IfIndex: cm.IfIndex, + } + } + return p.packetConnIPv6.WriteTo(b, c, dst) +} + +func (p *packetConnIPv6) SetMulticastHopLimit(hoplim int) error { + return p.packetConnIPv6.SetMulticastHopLimit(hoplim) +} + +func (p *packetConnIPv6) SetMulticastLoopback(on bool) error { + return p.packetConnIPv6.SetMulticastLoopback(on) +} + +func (p *packetConnIPv6) JoinGroup(ifi *net.Interface, group net.Addr) error { + return p.packetConnIPv6.JoinGroup(ifi, group) +} + +func (p *packetConnIPv6) LeaveGroup(ifi *net.Interface, group net.Addr) error { + return p.packetConnIPv6.LeaveGroup(ifi, group) +} + +func (p *packetConnIPv6) SetControlMessage(on bool) error { + return p.packetConnIPv6.SetMulticastLoopback(on) +} + +// IsIPv6 return's true if addr is IPV6. 
+func IsIPv6(addr net.IP) bool { + if ip := addr.To16(); ip != nil && ip.To4() == nil { + return true + } + return false +} + +var DefaultUDPConnConfig = UDPConnConfig{ + Errors: func(err error) { + // don't log any error from fails for multicast requests + }, +} + +type UDPConnConfig struct { + Errors func(err error) +} + +func NewListenUDP(network, addr string, opts ...UDPOption) (*UDPConn, error) { + listenAddress, err := net.ResolveUDPAddr(network, addr) + if err != nil { + return nil, err + } + conn, err := net.ListenUDP(network, listenAddress) + if err != nil { + return nil, err + } + return NewUDPConn(network, conn, opts...), nil +} + +// NewUDPConn creates connection over net.UDPConn. +func NewUDPConn(network string, c *net.UDPConn, opts ...UDPOption) *UDPConn { + cfg := DefaultUDPConnConfig + for _, o := range opts { + o.ApplyUDP(&cfg) + } + + laddr := c.LocalAddr() + if laddr == nil { + panic(fmt.Errorf("invalid UDP connection")) + } + addr, ok := laddr.(*net.UDPAddr) + if !ok { + panic(fmt.Errorf("invalid address type(%T), UDP address expected", laddr)) + } + var pc packetConn + if IsIPv6(addr.IP) { + pc = newPacketConnIPv6(ipv6.NewPacketConn(c)) + } else { + pc = newPacketConnIPv4(ipv4.NewPacketConn(c)) + } + + return &UDPConn{ + network: network, + connection: c, + packetConn: pc, + errors: cfg.Errors, + } +} + +// LocalAddr returns the local network address. The Addr returned is shared by all invocations of LocalAddr, so do not modify it. +func (c *UDPConn) LocalAddr() net.Addr { + return c.connection.LocalAddr() +} + +// RemoteAddr returns the remote network address. The Addr returned is shared by all invocations of RemoteAddr, so do not modify it. +func (c *UDPConn) RemoteAddr() net.Addr { + return c.connection.RemoteAddr() +} + +// Network name of the network (for example, udp4, udp6, udp) +func (c *UDPConn) Network() string { + return c.network +} + +// Close closes the connection. 
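+// It is safe to call Close multiple times: only the first call closes the
+// underlying socket, subsequent calls return nil. A typical lifecycle
+// sketch (address is illustrative):
+//
+//	conn, err := NewListenUDP("udp6", "[::1]:5683")
+//	if err != nil {
+//		// handle error
+//	}
+//	defer conn.Close()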
+func (c *UDPConn) Close() error {
+	if !c.closed.CompareAndSwap(false, true) {
+		return nil
+	}
+	return c.connection.Close()
+}
+
+func (c *UDPConn) writeToAddr(iface *net.Interface, src *net.IP, multicastHopLimit int, raddr *net.UDPAddr, buffer []byte) error {
+	var pktSrc net.IP
+	var p packetConn
+	if IsIPv6(raddr.IP) {
+		p = newPacketConnIPv6(ipv6.NewPacketConn(c.connection))
+		pktSrc = net.IPv6zero
+	} else {
+		p = newPacketConnIPv4(ipv4.NewPacketConn(c.connection))
+		pktSrc = net.IPv4zero
+	}
+	if src != nil {
+		pktSrc = *src
+	}
+
+	if c.closed.Load() {
+		return ErrConnectionIsClosed
+	}
+	if iface != nil {
+		if err := p.SetMulticastInterface(iface); err != nil {
+			return err
+		}
+	}
+	if err := p.SetMulticastHopLimit(multicastHopLimit); err != nil {
+		return err
+	}
+
+	var err error
+	if iface != nil || src != nil {
+		// iface may be nil when only the source address was specified;
+		// guard the dereference and leave the interface index unset (0).
+		ifaceIndex := 0
+		if iface != nil {
+			ifaceIndex = iface.Index
+		}
+		_, err = p.WriteTo(buffer, &ControlMessage{
+			Src:     pktSrc,
+			IfIndex: ifaceIndex,
+		}, raddr)
+	} else {
+		_, err = p.WriteTo(buffer, nil, raddr)
+	}
+	return err
+}
+
+func filterAddressesByNetwork(network string, ifaceAddrs []net.Addr) []net.Addr {
+	filtered := make([]net.Addr, 0, len(ifaceAddrs))
+	for _, srcAddr := range ifaceAddrs {
+		addrMask := srcAddr.String()
+		addr := strings.Split(addrMask, "/")[0]
+		if strings.Contains(addr, ":") && network == "udp4" {
+			continue
+		}
+		if !strings.Contains(addr, ":") && network == "udp6" {
+			continue
+		}
+		filtered = append(filtered, srcAddr)
+	}
+	return filtered
+}
+
+func convAddrsToIps(ifaceAddrs []net.Addr) []net.IP {
+	ips := make([]net.IP, 0, len(ifaceAddrs))
+	for _, addr := range ifaceAddrs {
+		addrMask := addr.String()
+		addr := strings.Split(addrMask, "/")[0]
+		ip := net.ParseIP(addr)
+		if ip != nil {
+			ips = append(ips, ip)
+		}
+	}
+	return ips
+}
+
+// WriteMulticast sends multicast to the remote multicast address.
+// By default it is sent over all network interfaces and all compatible source IP addresses with hop limit 1.
+// Via opts you can specify the network interface, source IP address, and hop limit.
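+//
+// A minimal sketch sending a payload to the IPv4 "All CoAP Nodes" group
+// from RFC 7252 (destination, payload and hop limit are illustrative):
+//
+//	raddr, err := net.ResolveUDPAddr("udp4", "224.0.1.187:5683")
+//	if err != nil {
+//		// handle error
+//	}
+//	err = conn.WriteMulticast(context.Background(), raddr, payload,
+//		WithMulticastHoplimit(2))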
+func (c *UDPConn) WriteMulticast(ctx context.Context, raddr *net.UDPAddr, buffer []byte, opts ...MulticastOption) error { + opt := MulticastOptions{ + HopLimit: 1, + } + for _, o := range opts { + o.applyMC(&opt) + } + return c.writeMulticast(ctx, raddr, buffer, opt) +} + +func (c *UDPConn) writeMulticastWithInterface(raddr *net.UDPAddr, buffer []byte, opt MulticastOptions) error { + if opt.Iface == nil && opt.IFaceMode == MulticastSpecificInterface { + return fmt.Errorf("invalid interface") + } + if opt.Source != nil { + return c.writeToAddr(opt.Iface, opt.Source, opt.HopLimit, raddr, buffer) + } + ifaceAddrs, err := opt.Iface.Addrs() + if err != nil { + return err + } + netType := "udp4" + if IsIPv6(raddr.IP) { + netType = "udp6" + } + var errors []error + for _, ip := range convAddrsToIps(filterAddressesByNetwork(netType, ifaceAddrs)) { + ipAddr := ip + opt.Source = &ipAddr + err = c.writeToAddr(opt.Iface, opt.Source, opt.HopLimit, raddr, buffer) + if err != nil { + errors = append(errors, err) + } + } + if errors == nil { + return nil + } + if len(errors) == 1 { + return errors[0] + } + return fmt.Errorf("%v", errors) +} + +func (c *UDPConn) writeMulticastToAllInterfaces(raddr *net.UDPAddr, buffer []byte, opt MulticastOptions) error { + ifaces, err := net.Interfaces() + if err != nil { + return fmt.Errorf("cannot get interfaces for multicast connection: %w", err) + } + + var errors []error + for i := range ifaces { + iface := ifaces[i] + if iface.Flags&net.FlagMulticast == 0 { + continue + } + if iface.Flags&net.FlagUp != net.FlagUp { + continue + } + specificOpt := opt + specificOpt.Iface = &iface + specificOpt.IFaceMode = MulticastSpecificInterface + err = c.writeMulticastWithInterface(raddr, buffer, specificOpt) + if err != nil { + if opt.InterfaceError != nil { + opt.InterfaceError(&iface, err) + continue + } + errors = append(errors, err) + } + } + if errors == nil { + return nil + } + if len(errors) == 1 { + return errors[0] + } + return fmt.Errorf("%v", errors) +} + +func (c *UDPConn) validateMulticast(ctx context.Context, raddr *net.UDPAddr, opt MulticastOptions) error { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + if raddr == nil { + return fmt.Errorf("cannot write multicast with context: invalid raddr") + } + if _, ok := c.packetConn.(*packetConnIPv4); ok && IsIPv6(raddr.IP) { + return fmt.Errorf("cannot write multicast with context: invalid destination address(%v)", raddr.IP) + } + if opt.Source != nil && IsIPv6(*opt.Source) && !IsIPv6(raddr.IP) { + return fmt.Errorf("cannot write multicast with context: invalid source address(%v) for destination(%v)", opt.Source, raddr.IP) + } + return nil +} + +func (c *UDPConn) writeMulticast(ctx context.Context, raddr *net.UDPAddr, buffer []byte, opt MulticastOptions) error { + err := c.validateMulticast(ctx, raddr, opt) + if err != nil { + return err + } + + switch opt.IFaceMode { + case MulticastAllInterface: + err := c.writeMulticastToAllInterfaces(raddr, buffer, opt) + if err != nil { + return fmt.Errorf("cannot write multicast to all interfaces: %w", err) + } + case MulticastAnyInterface: + err := c.writeToAddr(nil, opt.Source, opt.HopLimit, raddr, buffer) + if err != nil { + return fmt.Errorf("cannot write multicast to any: %w", err) + } + case MulticastSpecificInterface: + err := c.writeMulticastWithInterface(raddr, buffer, opt) + if err != nil { + if opt.InterfaceError != nil { + opt.InterfaceError(opt.Iface, err) + return nil + } + return fmt.Errorf("cannot write multicast to %v: %w", opt.Iface.Name, err) 
+ } + } + return nil +} + +// WriteWithContext writes data with context. +func (c *UDPConn) WriteWithContext(ctx context.Context, raddr *net.UDPAddr, buffer []byte) error { + if raddr == nil { + return fmt.Errorf("cannot write with context: invalid raddr") + } + + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + if c.closed.Load() { + return ErrConnectionIsClosed + } + n, err := WriteToUDP(c.connection, raddr, buffer) + if err != nil { + return err + } + if n != len(buffer) { + return ErrWriteInterrupted + } + + return nil +} + +// ReadWithContext reads packet with context. +func (c *UDPConn) ReadWithContext(ctx context.Context, buffer []byte) (int, *net.UDPAddr, error) { + select { + case <-ctx.Done(): + return -1, nil, ctx.Err() + default: + } + if c.closed.Load() { + return -1, nil, ErrConnectionIsClosed + } + n, s, err := c.connection.ReadFromUDP(buffer) + if err != nil { + return -1, nil, fmt.Errorf("cannot read from udp connection: %w", err) + } + return n, s, err +} + +// SetMulticastLoopback sets whether transmitted multicast packets +// should be copied and send back to the originator. +func (c *UDPConn) SetMulticastLoopback(on bool) error { + return c.packetConn.SetMulticastLoopback(on) +} + +// JoinGroup joins the group address group on the interface ifi. +// By default all sources that can cast data to group are accepted. +// It's possible to mute and unmute data transmission from a specific +// source by using ExcludeSourceSpecificGroup and +// IncludeSourceSpecificGroup. +// JoinGroup uses the system assigned multicast interface when ifi is +// nil, although this is not recommended because the assignment +// depends on platforms and sometimes it might require routing +// configuration. +func (c *UDPConn) JoinGroup(ifi *net.Interface, group net.Addr) error { + return c.packetConn.JoinGroup(ifi, group) +} + +// LeaveGroup leaves the group address group on the interface ifi +// regardless of whether the group is any-source group or source-specific group. +func (c *UDPConn) LeaveGroup(ifi *net.Interface, group net.Addr) error { + return c.packetConn.LeaveGroup(ifi, group) +} + +// NetConn returns the underlying connection that is wrapped by c. The Conn returned is shared by all invocations of NetConn, so do not modify it. +func (c *UDPConn) NetConn() net.Conn { + return c.connection +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/net/dtlslistener.go b/vendor/github.com/plgd-dev/go-coap/v3/net/dtlslistener.go new file mode 100644 index 0000000000..b797507f6c --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/net/dtlslistener.go @@ -0,0 +1,197 @@ +package net + +import ( + "context" + "errors" + "fmt" + "net" + "sync" + "time" + + dtls "github.com/pion/dtls/v2" + dtlsnet "github.com/pion/dtls/v2/pkg/net" + "github.com/pion/dtls/v2/pkg/protocol" + "github.com/pion/dtls/v2/pkg/protocol/recordlayer" + "github.com/pion/transport/v3/udp" + "go.uber.org/atomic" +) + +type GoPoolFunc = func(f func()) error + +var DefaultDTLSListenerConfig = DTLSListenerConfig{ + GoPool: func(f func()) error { + go f() + return nil + }, +} + +type DTLSListenerConfig struct { + GoPool GoPoolFunc +} + +type acceptedConn struct { + conn net.Conn + err error +} + +// DTLSListener is a DTLS listener that provides accept with context. 
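+//
+// A minimal accept-loop sketch (the DTLS configuration and the handle
+// callback are assumed to come from the application):
+//
+//	l, err := NewDTLSListener("udp4", ":5684", &dtls.Config{ /* PSK or certificates */ })
+//	if err != nil {
+//		// handle error
+//	}
+//	defer l.Close()
+//	for {
+//		c, err := l.AcceptWithContext(ctx)
+//		if err != nil {
+//			break
+//		}
+//		go handle(c)
+//	}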
+type DTLSListener struct { + listener net.Listener + config *dtls.Config + closed atomic.Bool + goPool GoPoolFunc + acceptedConnChan chan acceptedConn + wg sync.WaitGroup + done chan struct{} +} + +func tlsPacketFilter(packet []byte) bool { + pkts, err := recordlayer.UnpackDatagram(packet) + if err != nil || len(pkts) < 1 { + return false + } + h := &recordlayer.Header{} + if err := h.Unmarshal(pkts[0]); err != nil { + return false + } + return h.ContentType == protocol.ContentTypeHandshake +} + +// NewDTLSListener creates dtls listener. +// Known networks are "udp", "udp4" (IPv4-only), "udp6" (IPv6-only). +func NewDTLSListener(network string, addr string, dtlsCfg *dtls.Config, opts ...DTLSListenerOption) (*DTLSListener, error) { + a, err := net.ResolveUDPAddr(network, addr) + if err != nil { + return nil, fmt.Errorf("cannot resolve address: %w", err) + } + cfg := DefaultDTLSListenerConfig + for _, o := range opts { + o.ApplyDTLS(&cfg) + } + + if cfg.GoPool == nil { + return nil, fmt.Errorf("empty go pool") + } + + l := DTLSListener{ + goPool: cfg.GoPool, + config: dtlsCfg, + acceptedConnChan: make(chan acceptedConn, 256), + done: make(chan struct{}), + } + connectContextMaker := dtlsCfg.ConnectContextMaker + if connectContextMaker == nil { + connectContextMaker = func() (context.Context, func()) { + return context.WithTimeout(context.Background(), 30*time.Second) + } + } + dtlsCfg.ConnectContextMaker = func() (context.Context, func()) { + ctx, cancel := connectContextMaker() + if l.closed.Load() { + cancel() + } + return ctx, cancel + } + + lc := udp.ListenConfig{ + AcceptFilter: tlsPacketFilter, + } + l.listener, err = lc.Listen(network, a) + if err != nil { + return nil, err + } + l.wg.Add(1) + go l.run() + return &l, nil +} + +func (l *DTLSListener) send(conn net.Conn, err error) { + select { + case <-l.done: + case l.acceptedConnChan <- acceptedConn{ + conn: conn, + err: err, + }: + } +} + +func (l *DTLSListener) accept() error { + c, err := l.listener.Accept() + if err != nil { + l.send(nil, err) + return err + } + err = l.goPool(func() { + l.send(dtls.Server(dtlsnet.PacketConnFromConn(c), c.RemoteAddr(), l.config)) + }) + if err != nil { + _ = c.Close() + } + return err +} + +func (l *DTLSListener) run() { + defer l.wg.Done() + for { + if l.closed.Load() { + return + } + err := l.accept() + if errors.Is(err, udp.ErrClosedListener) { + return + } + } +} + +// AcceptWithContext waits with context for a generic Conn. +func (l *DTLSListener) AcceptWithContext(ctx context.Context) (net.Conn, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + if l.closed.Load() { + return nil, ErrListenerIsClosed + } + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-l.done: + return nil, ErrListenerIsClosed + case d := <-l.acceptedConnChan: + err := d.err + if errors.Is(err, context.DeadlineExceeded) { + // we don't want to report error handshake deadline exceeded + continue + } + if errors.Is(err, udp.ErrClosedListener) { + return nil, ErrListenerIsClosed + } + if err != nil { + return nil, err + } + return d.conn, nil + } + } +} + +// Accept waits for a generic Conn. +func (l *DTLSListener) Accept() (net.Conn, error) { + return l.AcceptWithContext(context.Background()) +} + +// Close closes the connection. +func (l *DTLSListener) Close() error { + if !l.closed.CompareAndSwap(false, true) { + return nil + } + close(l.done) + defer l.wg.Wait() + return l.listener.Close() +} + +// Addr represents a network end point address. 
+func (l *DTLSListener) Addr() net.Addr { + return l.listener.Addr() +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/net/error.go b/vendor/github.com/plgd-dev/go-coap/v3/net/error.go new file mode 100644 index 0000000000..be564170a2 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/net/error.go @@ -0,0 +1,25 @@ +package net + +import ( + "context" + "errors" + "io" + "net" +) + +var ( + ErrListenerIsClosed = io.EOF + ErrConnectionIsClosed = io.EOF + ErrWriteInterrupted = errors.New("only part data was written to socket") +) + +func IsCancelOrCloseError(err error) bool { + if err == nil { + return false + } + if errors.Is(err, context.Canceled) || errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) { + // this error was produced by cancellation context or closing connection. + return true + } + return false +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/net/error_unix.go b/vendor/github.com/plgd-dev/go-coap/v3/net/error_unix.go new file mode 100644 index 0000000000..164f111f6c --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/net/error_unix.go @@ -0,0 +1,16 @@ +//go:build aix || darwin || dragonfly || freebsd || js || linux || netbsd || openbsd || solaris +// +build aix darwin dragonfly freebsd js linux netbsd openbsd solaris + +package net + +import ( + "errors" + "syscall" +) + +// Check if error returned by operation on a socket failed because +// the other side has closed the connection. +func IsConnectionBrokenError(err error) bool { + return errors.Is(err, syscall.EPIPE) || + errors.Is(err, syscall.ECONNRESET) +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/net/error_windows.go b/vendor/github.com/plgd-dev/go-coap/v3/net/error_windows.go new file mode 100644 index 0000000000..f745c9bc25 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/net/error_windows.go @@ -0,0 +1,14 @@ +//go:build windows +// +build windows + +package net + +import ( + "errors" + "syscall" +) + +func IsConnectionBrokenError(err error) bool { + return errors.Is(err, syscall.WSAECONNRESET) || + errors.Is(err, syscall.WSAECONNABORTED) +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/net/monitor/inactivity/keepalive.go b/vendor/github.com/plgd-dev/go-coap/v3/net/monitor/inactivity/keepalive.go new file mode 100644 index 0000000000..dcb34fea90 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/net/monitor/inactivity/keepalive.go @@ -0,0 +1,63 @@ +package inactivity + +import ( + "unsafe" + + "go.uber.org/atomic" +) + +type cancelPingFunc func() + +type KeepAlive[C Conn] struct { + pongToken atomic.Uint64 + onInactive OnInactiveFunc[C] + + sendPing func(cc C, receivePong func()) (func(), error) + cancelPing atomic.UnsafePointer + numFails atomic.Uint32 + + maxRetries uint32 +} + +func NewKeepAlive[C Conn](maxRetries uint32, onInactive OnInactiveFunc[C], sendPing func(cc C, receivePong func()) (func(), error)) *KeepAlive[C] { + return &KeepAlive[C]{ + maxRetries: maxRetries, + sendPing: sendPing, + onInactive: onInactive, + } +} + +func (m *KeepAlive[C]) checkCancelPing() { + cancelPingPtr := m.cancelPing.Swap(nil) + if cancelPingPtr != nil { + cancelPing := *(*cancelPingFunc)(cancelPingPtr) + cancelPing() + } +} + +func (m *KeepAlive[C]) OnInactive(cc C) { + v := m.incrementFails() + m.checkCancelPing() + if v > m.maxRetries { + m.onInactive(cc) + return + } + pongToken := m.pongToken.Add(1) + cancel, err := m.sendPing(cc, func() { + if m.pongToken.Load() == pongToken { + m.resetFails() + } + }) + if err != nil { + return + } + 
m.cancelPing.Store(unsafe.Pointer(&cancel)) +} + +func (m *KeepAlive[C]) incrementFails() uint32 { + return m.numFails.Add(1) +} + +func (m *KeepAlive[C]) resetFails() { + m.numFails.Store(0) +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/net/monitor/inactivity/monitor.go b/vendor/github.com/plgd-dev/go-coap/v3/net/monitor/inactivity/monitor.go new file mode 100644 index 0000000000..33c020217a --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/net/monitor/inactivity/monitor.go @@ -0,0 +1,68 @@ +package inactivity + +import ( + "context" + "sync/atomic" + "time" +) + +type OnInactiveFunc[C Conn] func(cc C) + +type Conn = interface { + Context() context.Context + Close() error +} + +type Monitor[C Conn] struct { + lastActivity atomic.Value + duration time.Duration + onInactive OnInactiveFunc[C] +} + +func (m *Monitor[C]) Notify() { + m.lastActivity.Store(time.Now()) +} + +func (m *Monitor[C]) LastActivity() time.Time { + if t, ok := m.lastActivity.Load().(time.Time); ok { + return t + } + return time.Time{} +} + +func CloseConn(cc Conn) { + // call cc.Close() directly to check and handle error if necessary + _ = cc.Close() +} + +func New[C Conn](duration time.Duration, onInactive OnInactiveFunc[C]) *Monitor[C] { + m := &Monitor[C]{ + duration: duration, + onInactive: onInactive, + } + m.Notify() + return m +} + +func (m *Monitor[C]) CheckInactivity(now time.Time, cc C) { + if m.onInactive == nil || m.duration == time.Duration(0) { + return + } + if now.After(m.LastActivity().Add(m.duration)) { + m.onInactive(cc) + } +} + +type NilMonitor[C Conn] struct{} + +func (m *NilMonitor[C]) CheckInactivity(time.Time, C) { + // do nothing +} + +func (m *NilMonitor[C]) Notify() { + // do nothing +} + +func NewNilMonitor[C Conn]() *NilMonitor[C] { + return &NilMonitor[C]{} +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/net/observation/handler.go b/vendor/github.com/plgd-dev/go-coap/v3/net/observation/handler.go new file mode 100644 index 0000000000..f170269f54 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/net/observation/handler.go @@ -0,0 +1,267 @@ +package observation + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/plgd-dev/go-coap/v3/message" + "github.com/plgd-dev/go-coap/v3/message/codes" + "github.com/plgd-dev/go-coap/v3/message/pool" + "github.com/plgd-dev/go-coap/v3/net/responsewriter" + "github.com/plgd-dev/go-coap/v3/pkg/errors" + coapSync "github.com/plgd-dev/go-coap/v3/pkg/sync" + "go.uber.org/atomic" +) + +type DoFunc = func(req *pool.Message) (*pool.Message, error) + +type Client interface { + Context() context.Context + WriteMessage(req *pool.Message) error + ReleaseMessage(msg *pool.Message) + AcquireMessage(ctx context.Context) *pool.Message +} + +// The HandlerFunc type is an adapter to allow the use of +// ordinary functions as COAP handlers. 
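+//
+// For example, an ordinary function with the matching signature can serve
+// as the fallback for traffic that does not belong to a tracked
+// observation (C stands for the concrete connection type; cc, next and do
+// are illustrative):
+//
+//	next := func(w *responsewriter.ResponseWriter[C], r *pool.Message) {
+//		// non-observation traffic ends up here
+//	}
+//	h := NewHandler(cc, next, do)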
+type HandlerFunc[C Client] func(*responsewriter.ResponseWriter[C], *pool.Message) + +type Handler[C Client] struct { + cc C + observations *coapSync.Map[uint64, *Observation[C]] + next HandlerFunc[C] + do DoFunc +} + +func (h *Handler[C]) Handle(w *responsewriter.ResponseWriter[C], r *pool.Message) { + if o, ok := h.observations.Load(r.Token().Hash()); ok { + o.handle(r) + return + } + h.next(w, r) +} + +func (h *Handler[C]) client() C { + return h.cc +} + +func (h *Handler[C]) NewObservation(req *pool.Message, observeFunc func(req *pool.Message)) (*Observation[C], error) { + observe, err := req.Observe() + if err != nil { + return nil, fmt.Errorf("cannot get observe option: %w", err) + } + if observe != 0 { + return nil, fmt.Errorf("invalid value of observe(%v): expected 0", observe) + } + token := req.Token() + if len(token) == 0 { + return nil, fmt.Errorf("empty token") + } + options, err := req.Options().Clone() + if err != nil { + return nil, fmt.Errorf("cannot clone options: %w", err) + } + respObservationChan := make(chan respObservationMessage, 1) + o := newObservation(message.Message{ + Token: req.Token(), + Code: req.Code(), + Options: options, + }, h, observeFunc, respObservationChan) + defer func(err *error) { + if *err != nil { + o.cleanUp() + } + }(&err) + if _, loaded := h.observations.LoadOrStore(token.Hash(), o); loaded { + err = errors.ErrKeyAlreadyExists + return nil, err + } + + err = h.cc.WriteMessage(req) + if err != nil { + return nil, err + } + select { + case <-req.Context().Done(): + err = req.Context().Err() + return nil, err + case <-h.cc.Context().Done(): + err = fmt.Errorf("connection was closed: %w", h.cc.Context().Err()) + return nil, err + case resp := <-respObservationChan: + if resp.code != codes.Content && resp.code != codes.Valid { + err = fmt.Errorf("unexpected return code(%v)", resp.code) + return nil, err + } + if resp.notSupported { + o.cleanUp() + } + return o, nil + } +} + +func (h *Handler[C]) GetObservation(key uint64) (*Observation[C], bool) { + return h.observations.Load(key) +} + +// GetObservationRequest returns observation request for token +func (h *Handler[C]) GetObservationRequest(token message.Token) (*pool.Message, bool) { + obs, ok := h.GetObservation(token.Hash()) + if !ok { + return nil, false + } + req := obs.Request() + msg := h.cc.AcquireMessage(h.cc.Context()) + msg.ResetOptionsTo(req.Options) + msg.SetCode(req.Code) + msg.SetToken(req.Token) + return msg, true +} + +func (h *Handler[C]) pullOutObservation(key uint64) (*Observation[C], bool) { + return h.observations.LoadAndDelete(key) +} + +func NewHandler[C Client](cc C, next HandlerFunc[C], do DoFunc) *Handler[C] { + return &Handler[C]{ + cc: cc, + observations: coapSync.NewMap[uint64, *Observation[C]](), + next: next, + do: do, + } +} + +type respObservationMessage struct { + code codes.Code + notSupported bool +} + +// Observation represents subscription to resource on the server +type Observation[C Client] struct { + req message.Message + observeFunc func(req *pool.Message) + respObservationChan chan respObservationMessage + waitForResponse atomic.Bool + observationHandler *Handler[C] + + private struct { // members guarded by mutex + mutex sync.Mutex + obsSequence uint32 + lastEvent time.Time + etag []byte + } +} + +func (o *Observation[C]) Canceled() bool { + _, ok := o.observationHandler.GetObservation(o.req.Token.Hash()) + return !ok +} + +func newObservation[C Client](req message.Message, observationHandler *Handler[C], observeFunc func(req *pool.Message), 
respObservationChan chan respObservationMessage) *Observation[C] { + return &Observation[C]{ + req: req, + waitForResponse: *atomic.NewBool(true), + respObservationChan: respObservationChan, + observeFunc: observeFunc, + observationHandler: observationHandler, + } +} + +func (o *Observation[C]) handle(r *pool.Message) { + if o.waitForResponse.CompareAndSwap(true, false) { + select { + case o.respObservationChan <- respObservationMessage{ + code: r.Code(), + notSupported: !r.HasOption(message.Observe), + }: + default: + } + o.respObservationChan = nil + } + if o.wantBeNotified(r) { + o.observeFunc(r) + } +} + +func (o *Observation[C]) cleanUp() bool { + // we can ignore err during cleanUp, if err != nil then some other + // part of code already removed the handler for the token + _, ok := o.observationHandler.pullOutObservation(o.req.Token.Hash()) + return ok +} + +func (o *Observation[C]) client() C { + return o.observationHandler.client() +} + +func (o *Observation[C]) Request() message.Message { + return o.req +} + +func (o *Observation[C]) etag() []byte { + o.private.mutex.Lock() + defer o.private.mutex.Unlock() + return o.private.etag +} + +// Cancel remove observation from server. For recreate observation use Observe. +func (o *Observation[C]) Cancel(ctx context.Context, opts ...message.Option) error { + if !o.cleanUp() { + // observation was already cleanup + return nil + } + + req := o.client().AcquireMessage(ctx) + defer o.client().ReleaseMessage(req) + req.ResetOptionsTo(opts) + req.SetCode(codes.GET) + req.SetObserve(1) + if path, err := o.req.Options.Path(); err == nil { + if err := req.SetPath(path); err != nil { + return fmt.Errorf("cannot set path(%v): %w", path, err) + } + } + req.SetToken(o.req.Token) + etag := o.etag() + if len(etag) > 0 { + _ = req.SetETag(etag) // ignore invalid etag + } + resp, err := o.observationHandler.do(req) + if err != nil { + return err + } + defer o.client().ReleaseMessage(resp) + if resp.Code() != codes.Content && resp.Code() != codes.Valid { + return fmt.Errorf("unexpected return code(%v)", resp.Code()) + } + return nil +} + +func (o *Observation[C]) wantBeNotified(r *pool.Message) bool { + obsSequence, err := r.Observe() + if err != nil { + return true + } + now := time.Now() + + o.private.mutex.Lock() + defer o.private.mutex.Unlock() + if !ValidSequenceNumber(o.private.obsSequence, obsSequence, o.private.lastEvent, now) { + return false + } + + o.private.obsSequence = obsSequence + o.private.lastEvent = now + if etag, err := r.ETag(); err == nil { + if cap(o.private.etag) < len(etag) { + o.private.etag = make([]byte, len(etag)) + } + if len(o.private.etag) != len(etag) { + o.private.etag = o.private.etag[:len(etag)] + } + copy(o.private.etag, etag) + } + return true +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/net/observation/observation.go b/vendor/github.com/plgd-dev/go-coap/v3/net/observation/observation.go new file mode 100644 index 0000000000..074d3156d1 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/net/observation/observation.go @@ -0,0 +1,22 @@ +package observation + +import ( + "time" +) + +// ObservationSequenceTimeout defines how long is sequence number is valid. 
https://tools.ietf.org/html/rfc7641#section-3.4 +const ObservationSequenceTimeout = 128 * time.Second + +// ValidSequenceNumber implements conditions in https://tools.ietf.org/html/rfc7641#section-3.4 +func ValidSequenceNumber(oldValue, newValue uint32, lastEventOccurs time.Time, now time.Time) bool { + if oldValue < newValue && (newValue-oldValue) < (1<<23) { + return true + } + if oldValue > newValue && (oldValue-newValue) > (1<<23) { + return true + } + if now.Sub(lastEventOccurs) > ObservationSequenceTimeout { + return true + } + return false +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/net/options.go b/vendor/github.com/plgd-dev/go-coap/v3/net/options.go new file mode 100644 index 0000000000..c037ef9c28 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/net/options.go @@ -0,0 +1,143 @@ +package net + +import "net" + +// A UDPOption sets options such as errors parameters, etc. +type UDPOption interface { + ApplyUDP(*UDPConnConfig) +} + +type ErrorsOpt struct { + errors func(err error) +} + +func (h ErrorsOpt) ApplyUDP(o *UDPConnConfig) { + o.Errors = h.errors +} + +func WithErrors(v func(err error)) ErrorsOpt { + return ErrorsOpt{ + errors: v, + } +} + +func DefaultMulticastOptions() MulticastOptions { + return MulticastOptions{ + IFaceMode: MulticastAllInterface, + HopLimit: 1, + } +} + +type MulticastInterfaceMode int + +const ( + MulticastAllInterface MulticastInterfaceMode = 0 + MulticastAnyInterface MulticastInterfaceMode = 1 + MulticastSpecificInterface MulticastInterfaceMode = 2 +) + +type InterfaceError = func(iface *net.Interface, err error) + +type MulticastOptions struct { + IFaceMode MulticastInterfaceMode + Iface *net.Interface + Source *net.IP + HopLimit int + InterfaceError InterfaceError +} + +func (m *MulticastOptions) Apply(o MulticastOption) { + o.applyMC(m) +} + +// A MulticastOption sets options such as hop limit, etc. +type MulticastOption interface { + applyMC(*MulticastOptions) +} + +type MulticastInterfaceModeOpt struct { + mode MulticastInterfaceMode +} + +func (m MulticastInterfaceModeOpt) applyMC(o *MulticastOptions) { + o.IFaceMode = m.mode +} + +func WithAnyMulticastInterface() MulticastOption { + return MulticastInterfaceModeOpt{mode: MulticastAnyInterface} +} + +func WithAllMulticastInterface() MulticastOption { + return MulticastInterfaceModeOpt{mode: MulticastAllInterface} +} + +type MulticastInterfaceOpt struct { + iface net.Interface +} + +func (m MulticastInterfaceOpt) applyMC(o *MulticastOptions) { + o.Iface = &m.iface + o.IFaceMode = MulticastSpecificInterface +} + +func WithMulticastInterface(iface net.Interface) MulticastOption { + return &MulticastInterfaceOpt{iface: iface} +} + +type MulticastHoplimitOpt struct { + hoplimit int +} + +func (m MulticastHoplimitOpt) applyMC(o *MulticastOptions) { + o.HopLimit = m.hoplimit +} + +func WithMulticastHoplimit(hoplimit int) MulticastOption { + return &MulticastHoplimitOpt{hoplimit: hoplimit} +} + +type MulticastSourceOpt struct { + source net.IP +} + +func (m MulticastSourceOpt) applyMC(o *MulticastOptions) { + o.Source = &m.source +} + +func WithMulticastSource(source net.IP) MulticastOption { + return &MulticastSourceOpt{source: source} +} + +type MulticastInterfaceErrorOpt struct { + interfaceError InterfaceError +} + +func (m MulticastInterfaceErrorOpt) applyMC(o *MulticastOptions) { + o.InterfaceError = m.interfaceError +} + +// WithMulticastInterfaceError sets the callback for interface errors. If it is set error is not propagated as result of WriteMulticast. 
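+//
+// For example, to log per-interface send failures while letting the
+// remaining interfaces proceed (the logger is illustrative):
+//
+//	opt := WithMulticastInterfaceError(func(iface *net.Interface, err error) {
+//		if iface != nil {
+//			log.Printf("multicast via %s: %v", iface.Name, err)
+//		}
+//	})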
+func WithMulticastInterfaceError(interfaceError InterfaceError) MulticastOption {
+	return &MulticastInterfaceErrorOpt{interfaceError: interfaceError}
+}
+
+// A DTLSListenerOption sets options such as gopool.
+type DTLSListenerOption interface {
+	ApplyDTLS(*DTLSListenerConfig)
+}
+
+// GoPoolOpt gopool option.
+type GoPoolOpt struct {
+	goPool GoPoolFunc
+}
+
+func (o GoPoolOpt) ApplyDTLS(cfg *DTLSListenerConfig) {
+	cfg.GoPool = o.goPool
+}
+
+// WithGoPool sets function for managing spawning go routines
+// for handling incoming requests.
+// Eg: https://github.com/panjf2000/ants.
+func WithGoPool(goPool GoPoolFunc) GoPoolOpt {
+	return GoPoolOpt{goPool: goPool}
+}
diff --git a/vendor/github.com/plgd-dev/go-coap/v3/net/responsewriter/responseWriter.go b/vendor/github.com/plgd-dev/go-coap/v3/net/responsewriter/responseWriter.go
new file mode 100644
index 0000000000..44fe2d9f70
--- /dev/null
+++ b/vendor/github.com/plgd-dev/go-coap/v3/net/responsewriter/responseWriter.go
@@ -0,0 +1,79 @@
+package responsewriter
+
+import (
+	"io"
+
+	"github.com/plgd-dev/go-coap/v3/message"
+	"github.com/plgd-dev/go-coap/v3/message/codes"
+	"github.com/plgd-dev/go-coap/v3/message/noresponse"
+	"github.com/plgd-dev/go-coap/v3/message/pool"
+)
+
+type Client interface {
+	ReleaseMessage(msg *pool.Message)
+}
+
+// A ResponseWriter is used by a COAP handler to construct a COAP response.
+type ResponseWriter[C Client] struct {
+	noResponseValue *uint32
+	response        *pool.Message
+	cc              C
+}
+
+func New[C Client](response *pool.Message, cc C, requestOptions ...message.Option) *ResponseWriter[C] {
+	var noResponseValue *uint32
+	if len(requestOptions) > 0 {
+		reqOpts := message.Options(requestOptions)
+		v, err := reqOpts.GetUint32(message.NoResponse)
+		if err == nil {
+			noResponseValue = &v
+		}
+	}
+
+	return &ResponseWriter[C]{
+		response:        response,
+		cc:              cc,
+		noResponseValue: noResponseValue,
+	}
+}
+
+// SetResponse simplifies the setup of the response for the request. ETags must be set via options. For advanced setup, use Message().
+func (r *ResponseWriter[C]) SetResponse(code codes.Code, contentFormat message.MediaType, d io.ReadSeeker, opts ...message.Option) error {
+	if r.noResponseValue != nil {
+		err := noresponse.IsNoResponseCode(code, *r.noResponseValue)
+		if err != nil {
+			return err
+		}
+	}
+
+	r.response.SetCode(code)
+	r.response.ResetOptionsTo(opts)
+	if d != nil {
+		r.response.SetContentFormat(contentFormat)
+		r.response.SetBody(d)
+	}
+	return nil
+}
+
+// SetMessage replaces the response message. The original message is released to the message pool, so don't use it any more. Ensure that Token, MessageID(udp), and Type(udp) are paired correctly.
+func (r *ResponseWriter[C]) SetMessage(m *pool.Message) {
+	r.cc.ReleaseMessage(r.response)
+	r.response = m
+}
+
+// Message direct access to the response.
+func (r *ResponseWriter[C]) Message() *pool.Message {
+	return r.response
+}
+
+// Swap message in response without releasing.
+func (r *ResponseWriter[C]) Swap(m *pool.Message) *pool.Message {
+	tmp := r.response
+	r.response = m
+	return tmp
+}
+
+// Conn peer connection.
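+//
+// A handler typically fills in the response through the writer and can
+// reach the peer through Conn, for example (code, media type and payload
+// are illustrative):
+//
+//	func handle[C Client](w *ResponseWriter[C], r *pool.Message) {
+//		_ = w.SetResponse(codes.Content, message.TextPlain,
+//			bytes.NewReader([]byte("hello")))
+//		_ = w.Conn() // peer connection
+//	}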
+func (r *ResponseWriter[C]) Conn() C { + return r.cc +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/net/tcplistener.go b/vendor/github.com/plgd-dev/go-coap/v3/net/tcplistener.go new file mode 100644 index 0000000000..43cc955ef0 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/net/tcplistener.go @@ -0,0 +1,69 @@ +package net + +import ( + "context" + "fmt" + "net" + + "go.uber.org/atomic" +) + +// TCPListener is a TCP network listener that provides accept with context. +type TCPListener struct { + listener *net.TCPListener + closed atomic.Bool +} + +func newNetTCPListen(network string, addr string) (*net.TCPListener, error) { + a, err := net.ResolveTCPAddr(network, addr) + if err != nil { + return nil, fmt.Errorf("cannot create new net tcp listener: %w", err) + } + + tcp, err := net.ListenTCP(network, a) + if err != nil { + return nil, fmt.Errorf("cannot create new net tcp listener: %w", err) + } + return tcp, nil +} + +// NewTCPListener creates tcp listener. +// Known networks are "tcp", "tcp4" (IPv4-only), "tcp6" (IPv6-only). +func NewTCPListener(network string, addr string) (*TCPListener, error) { + tcp, err := newNetTCPListen(network, addr) + if err != nil { + return nil, fmt.Errorf("cannot create new tcp listener: %w", err) + } + return &TCPListener{listener: tcp}, nil +} + +// AcceptWithContext waits with context for a generic Conn. +func (l *TCPListener) AcceptWithContext(ctx context.Context) (net.Conn, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + if l.closed.Load() { + return nil, ErrListenerIsClosed + } + return l.listener.Accept() +} + +// Accept waits for a generic Conn. +func (l *TCPListener) Accept() (net.Conn, error) { + return l.AcceptWithContext(context.Background()) +} + +// Close closes the connection. +func (l *TCPListener) Close() error { + if !l.closed.CompareAndSwap(false, true) { + return nil + } + return l.listener.Close() +} + +// Addr represents a network end point address. +func (l *TCPListener) Addr() net.Addr { + return l.listener.Addr() +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/net/tlslistener.go b/vendor/github.com/plgd-dev/go-coap/v3/net/tlslistener.go new file mode 100644 index 0000000000..df12b1a66c --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/net/tlslistener.go @@ -0,0 +1,66 @@ +package net + +import ( + "context" + "crypto/tls" + "fmt" + "net" + + "go.uber.org/atomic" +) + +// TLSListener is a TLS listener that provides accept with context. +type TLSListener struct { + listener net.Listener + tcp *net.TCPListener + closed atomic.Bool +} + +// NewTLSListener creates tcp listener. +// Known networks are "tcp", "tcp4" (IPv4-only), "tcp6" (IPv6-only). +func NewTLSListener(network string, addr string, tlsCfg *tls.Config) (*TLSListener, error) { + tcp, err := newNetTCPListen(network, addr) + if err != nil { + return nil, fmt.Errorf("cannot create new tls listener: %w", err) + } + tls := tls.NewListener(tcp, tlsCfg) + return &TLSListener{ + tcp: tcp, + listener: tls, + }, nil +} + +// AcceptWithContext waits with context for a generic Conn. +func (l *TLSListener) AcceptWithContext(ctx context.Context) (net.Conn, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + if l.closed.Load() { + return nil, ErrListenerIsClosed + } + rw, err := l.listener.Accept() + if err != nil { + return nil, err + } + return rw, nil +} + +// Accept waits for a generic Conn. 
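+//
+// A minimal accept-loop sketch (certificate loading and the serve callback
+// are assumed to come from the application):
+//
+//	cert, err := tls.LoadX509KeyPair("server.crt", "server.key")
+//	if err != nil {
+//		// handle error
+//	}
+//	l, err := NewTLSListener("tcp", ":5684", &tls.Config{Certificates: []tls.Certificate{cert}})
+//	if err != nil {
+//		// handle error
+//	}
+//	defer l.Close()
+//	for {
+//		c, err := l.Accept()
+//		if err != nil {
+//			break
+//		}
+//		go serve(c)
+//	}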
+func (l *TLSListener) Accept() (net.Conn, error) { + return l.AcceptWithContext(context.Background()) +} + +// Close closes the connection. +func (l *TLSListener) Close() error { + if !l.closed.CompareAndSwap(false, true) { + return nil + } + return l.listener.Close() +} + +// Addr represents a network end point address. +func (l *TLSListener) Addr() net.Addr { + return l.listener.Addr() +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/net/udp.go b/vendor/github.com/plgd-dev/go-coap/v3/net/udp.go new file mode 100644 index 0000000000..0459266cf8 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/net/udp.go @@ -0,0 +1,15 @@ +package net + +import ( + "net" +) + +// WriteToUDP acts just like net.UDPConn.WriteTo(), but uses a *SessionUDP instead of a net.Addr. +func WriteToUDP(conn *net.UDPConn, raddr *net.UDPAddr, b []byte) (int, error) { + if conn.RemoteAddr() == nil { + // Connection remote address must be nil otherwise + // "WriteTo with pre-connected connection" will be thrown + return conn.WriteToUDP(b, raddr) + } + return conn.Write(b) +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/options/commonOptions.go b/vendor/github.com/plgd-dev/go-coap/v3/options/commonOptions.go new file mode 100644 index 0000000000..535bb8bcb8 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/options/commonOptions.go @@ -0,0 +1,786 @@ +package options + +import ( + "context" + "fmt" + "net" + "time" + + dtlsServer "github.com/plgd-dev/go-coap/v3/dtls/server" + "github.com/plgd-dev/go-coap/v3/message/pool" + "github.com/plgd-dev/go-coap/v3/mux" + "github.com/plgd-dev/go-coap/v3/net/blockwise" + "github.com/plgd-dev/go-coap/v3/net/client" + "github.com/plgd-dev/go-coap/v3/net/monitor/inactivity" + "github.com/plgd-dev/go-coap/v3/net/responsewriter" + "github.com/plgd-dev/go-coap/v3/options/config" + "github.com/plgd-dev/go-coap/v3/pkg/runner/periodic" + tcpClient "github.com/plgd-dev/go-coap/v3/tcp/client" + tcpServer "github.com/plgd-dev/go-coap/v3/tcp/server" + udpClient "github.com/plgd-dev/go-coap/v3/udp/client" + udpServer "github.com/plgd-dev/go-coap/v3/udp/server" +) + +type ErrorFunc = config.ErrorFunc + +type Handler interface { + tcpClient.HandlerFunc | udpClient.HandlerFunc +} + +// HandlerFuncOpt handler function option. 
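+//
+// For example, a UDP request handler can be supplied via WithHandlerFunc;
+// the option is later consumed by the matching server or client
+// constructor (the handler body is illustrative):
+//
+//	opt := WithHandlerFunc(func(w *responsewriter.ResponseWriter[*udpClient.Conn], r *pool.Message) {
+//		// serve the request
+//	})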
+type HandlerFuncOpt[H Handler] struct { + h H +} + +func panicForInvalidHandlerFunc(t, exp any) { + panic(fmt.Errorf("invalid HandlerFunc type %T, expected %T", t, exp)) +} + +func (o HandlerFuncOpt[H]) TCPServerApply(cfg *tcpServer.Config) { + switch v := any(o.h).(type) { + case tcpClient.HandlerFunc: + cfg.Handler = v + default: + var exp tcpClient.HandlerFunc + panicForInvalidHandlerFunc(v, exp) + } +} + +func (o HandlerFuncOpt[H]) TCPClientApply(cfg *tcpClient.Config) { + switch v := any(o.h).(type) { + case tcpClient.HandlerFunc: + cfg.Handler = v + default: + var exp tcpClient.HandlerFunc + panicForInvalidHandlerFunc(v, exp) + } +} + +func (o HandlerFuncOpt[H]) UDPServerApply(cfg *udpServer.Config) { + switch v := any(o.h).(type) { + case udpClient.HandlerFunc: + cfg.Handler = v + default: + var exp udpClient.HandlerFunc + panicForInvalidHandlerFunc(v, exp) + } +} + +func (o HandlerFuncOpt[H]) DTLSServerApply(cfg *dtlsServer.Config) { + switch v := any(o.h).(type) { + case udpClient.HandlerFunc: + cfg.Handler = v + default: + var exp udpClient.HandlerFunc + panicForInvalidHandlerFunc(v, exp) + } +} + +func (o HandlerFuncOpt[H]) UDPClientApply(cfg *udpClient.Config) { + switch v := any(o.h).(type) { + case udpClient.HandlerFunc: + cfg.Handler = v + default: + var t udpClient.HandlerFunc + panicForInvalidHandlerFunc(v, t) + } +} + +// WithHandlerFunc set handle for handling request's. +func WithHandlerFunc[H Handler](h H) HandlerFuncOpt[H] { + return HandlerFuncOpt[H]{h: h} +} + +// HandlerFuncOpt handler function option. +type MuxHandlerOpt struct { + m mux.Handler +} + +func (o MuxHandlerOpt) TCPServerApply(cfg *tcpServer.Config) { + cfg.Handler = mux.ToHandler[*tcpClient.Conn](o.m) +} + +func (o MuxHandlerOpt) TCPClientApply(cfg *tcpClient.Config) { + cfg.Handler = mux.ToHandler[*tcpClient.Conn](o.m) +} + +func (o MuxHandlerOpt) UDPServerApply(cfg *udpServer.Config) { + cfg.Handler = mux.ToHandler[*udpClient.Conn](o.m) +} + +func (o MuxHandlerOpt) DTLSServerApply(cfg *dtlsServer.Config) { + cfg.Handler = mux.ToHandler[*udpClient.Conn](o.m) +} + +func (o MuxHandlerOpt) UDPClientApply(cfg *udpClient.Config) { + cfg.Handler = mux.ToHandler[*udpClient.Conn](o.m) +} + +// WithMux set's multiplexer for handle requests. +func WithMux(m mux.Handler) MuxHandlerOpt { + return MuxHandlerOpt{ + m: m, + } +} + +// ContextOpt handler function option. +type ContextOpt struct { + ctx context.Context +} + +func (o ContextOpt) TCPServerApply(cfg *tcpServer.Config) { + cfg.Ctx = o.ctx +} + +func (o ContextOpt) TCPClientApply(cfg *tcpClient.Config) { + cfg.Ctx = o.ctx +} + +func (o ContextOpt) UDPServerApply(cfg *udpServer.Config) { + cfg.Ctx = o.ctx +} + +func (o ContextOpt) DTLSServerApply(cfg *dtlsServer.Config) { + cfg.Ctx = o.ctx +} + +func (o ContextOpt) UDPClientApply(cfg *udpClient.Config) { + cfg.Ctx = o.ctx +} + +// WithContext set's parent context of server. +func WithContext(ctx context.Context) ContextOpt { + return ContextOpt{ctx: ctx} +} + +// MaxMessageSizeOpt handler function option. 
+type MaxMessageSizeOpt struct { + maxMessageSize uint32 +} + +func (o MaxMessageSizeOpt) TCPServerApply(cfg *tcpServer.Config) { + cfg.MaxMessageSize = o.maxMessageSize +} + +func (o MaxMessageSizeOpt) TCPClientApply(cfg *tcpClient.Config) { + cfg.MaxMessageSize = o.maxMessageSize +} + +func (o MaxMessageSizeOpt) UDPServerApply(cfg *udpServer.Config) { + cfg.MaxMessageSize = o.maxMessageSize +} + +func (o MaxMessageSizeOpt) DTLSServerApply(cfg *dtlsServer.Config) { + cfg.MaxMessageSize = o.maxMessageSize +} + +func (o MaxMessageSizeOpt) UDPClientApply(cfg *udpClient.Config) { + cfg.MaxMessageSize = o.maxMessageSize +} + +// WithMaxMessageSize limit size of processed message. +func WithMaxMessageSize(maxMessageSize uint32) MaxMessageSizeOpt { + return MaxMessageSizeOpt{maxMessageSize: maxMessageSize} +} + +// ErrorsOpt errors option. +type ErrorsOpt struct { + errors ErrorFunc +} + +func (o ErrorsOpt) TCPServerApply(cfg *tcpServer.Config) { + cfg.Errors = o.errors +} + +func (o ErrorsOpt) TCPClientApply(cfg *tcpClient.Config) { + cfg.Errors = o.errors +} + +func (o ErrorsOpt) UDPServerApply(cfg *udpServer.Config) { + cfg.Errors = o.errors +} + +func (o ErrorsOpt) DTLSServerApply(cfg *dtlsServer.Config) { + cfg.Errors = o.errors +} + +func (o ErrorsOpt) UDPClientApply(cfg *udpClient.Config) { + cfg.Errors = o.errors +} + +// WithErrors set function for logging error. +func WithErrors(errors ErrorFunc) ErrorsOpt { + return ErrorsOpt{errors: errors} +} + +// ProcessReceivedMessageOpt gopool option. +type ProcessReceivedMessageOpt[C responsewriter.Client] struct { + ProcessReceivedMessageFunc config.ProcessReceivedMessageFunc[C] +} + +func panicForInvalidProcessReceivedMessageFunc(t, exp any) { + panic(fmt.Errorf("invalid ProcessReceivedMessageFunc type %T, expected %T", t, exp)) +} + +func (o ProcessReceivedMessageOpt[C]) TCPServerApply(cfg *tcpServer.Config) { + switch v := any(o.ProcessReceivedMessageFunc).(type) { + case config.ProcessReceivedMessageFunc[*tcpClient.Conn]: + cfg.ProcessReceivedMessage = v + default: + var t config.ProcessReceivedMessageFunc[*tcpClient.Conn] + panicForInvalidProcessReceivedMessageFunc(v, t) + } +} + +func (o ProcessReceivedMessageOpt[C]) TCPClientApply(cfg *tcpClient.Config) { + switch v := any(o.ProcessReceivedMessageFunc).(type) { + case config.ProcessReceivedMessageFunc[*tcpClient.Conn]: + cfg.ProcessReceivedMessage = v + default: + var t config.ProcessReceivedMessageFunc[*tcpClient.Conn] + panicForInvalidProcessReceivedMessageFunc(v, t) + } +} + +func (o ProcessReceivedMessageOpt[C]) UDPServerApply(cfg *udpServer.Config) { + switch v := any(o.ProcessReceivedMessageFunc).(type) { + case config.ProcessReceivedMessageFunc[*udpClient.Conn]: + cfg.ProcessReceivedMessage = v + default: + var t config.ProcessReceivedMessageFunc[*udpClient.Conn] + panicForInvalidProcessReceivedMessageFunc(v, t) + } +} + +func (o ProcessReceivedMessageOpt[C]) DTLSServerApply(cfg *dtlsServer.Config) { + switch v := any(o.ProcessReceivedMessageFunc).(type) { + case config.ProcessReceivedMessageFunc[*udpClient.Conn]: + cfg.ProcessReceivedMessage = v + default: + var t config.ProcessReceivedMessageFunc[*udpClient.Conn] + panicForInvalidProcessReceivedMessageFunc(v, t) + } +} + +func (o ProcessReceivedMessageOpt[C]) UDPClientApply(cfg *udpClient.Config) { + switch v := any(o.ProcessReceivedMessageFunc).(type) { + case config.ProcessReceivedMessageFunc[*udpClient.Conn]: + cfg.ProcessReceivedMessage = v + default: + var t config.ProcessReceivedMessageFunc[*udpClient.Conn] + 
panicForInvalidProcessReceivedMessageFunc(v, t) + } +} + +func WithProcessReceivedMessageFunc[C responsewriter.Client](processReceivedMessageFunc config.ProcessReceivedMessageFunc[C]) ProcessReceivedMessageOpt[C] { + return ProcessReceivedMessageOpt[C]{ProcessReceivedMessageFunc: processReceivedMessageFunc} +} + +type ( + UDPOnInactive = func(cc *udpClient.Conn) + TCPOnInactive = func(cc *tcpClient.Conn) +) + +type OnInactiveFunc interface { + UDPOnInactive | TCPOnInactive +} + +func panicForInvalidOnInactiveFunc(t, exp any) { + panic(fmt.Errorf("invalid OnInactiveFunc type %T, expected %T", t, exp)) +} + +// KeepAliveOpt keepalive option. +type KeepAliveOpt[C OnInactiveFunc] struct { + timeout time.Duration + onInactive C + maxRetries uint32 +} + +func (o KeepAliveOpt[C]) toTCPCreateInactivityMonitor(onInactive TCPOnInactive) func() tcpClient.InactivityMonitor { + return func() tcpClient.InactivityMonitor { + keepalive := inactivity.NewKeepAlive(o.maxRetries, onInactive, func(cc *tcpClient.Conn, receivePong func()) (func(), error) { + return cc.AsyncPing(receivePong) + }) + return inactivity.New(o.timeout/time.Duration(o.maxRetries+1), keepalive.OnInactive) + } +} + +func (o KeepAliveOpt[C]) toUDPCreateInactivityMonitor(onInactive UDPOnInactive) func() udpClient.InactivityMonitor { + return func() udpClient.InactivityMonitor { + keepalive := inactivity.NewKeepAlive(o.maxRetries, onInactive, func(cc *udpClient.Conn, receivePong func()) (func(), error) { + return cc.AsyncPing(receivePong) + }) + return inactivity.New(o.timeout/time.Duration(o.maxRetries+1), keepalive.OnInactive) + } +} + +func (o KeepAliveOpt[C]) TCPServerApply(cfg *tcpServer.Config) { + switch onInactive := any(o.onInactive).(type) { + case TCPOnInactive: + cfg.CreateInactivityMonitor = o.toTCPCreateInactivityMonitor(onInactive) + default: + var t TCPOnInactive + panicForInvalidOnInactiveFunc(onInactive, t) + } +} + +func (o KeepAliveOpt[C]) TCPClientApply(cfg *tcpClient.Config) { + switch onInactive := any(o.onInactive).(type) { + case TCPOnInactive: + cfg.CreateInactivityMonitor = o.toTCPCreateInactivityMonitor(onInactive) + default: + var t TCPOnInactive + panicForInvalidOnInactiveFunc(onInactive, t) + } +} + +func (o KeepAliveOpt[C]) UDPServerApply(cfg *udpServer.Config) { + switch onInactive := any(o.onInactive).(type) { + case UDPOnInactive: + cfg.CreateInactivityMonitor = o.toUDPCreateInactivityMonitor(onInactive) + default: + var t UDPOnInactive + panicForInvalidOnInactiveFunc(onInactive, t) + } +} + +func (o KeepAliveOpt[C]) DTLSServerApply(cfg *dtlsServer.Config) { + switch onInactive := any(o.onInactive).(type) { + case UDPOnInactive: + cfg.CreateInactivityMonitor = o.toUDPCreateInactivityMonitor(onInactive) + default: + var t UDPOnInactive + panicForInvalidOnInactiveFunc(onInactive, t) + } +} + +func (o KeepAliveOpt[C]) UDPClientApply(cfg *udpClient.Config) { + switch onInactive := any(o.onInactive).(type) { + case UDPOnInactive: + cfg.CreateInactivityMonitor = o.toUDPCreateInactivityMonitor(onInactive) + default: + var t UDPOnInactive + panicForInvalidOnInactiveFunc(onInactive, t) + } +} + +// WithKeepAlive monitoring's client connection's. +func WithKeepAlive[C OnInactiveFunc](maxRetries uint32, timeout time.Duration, onInactive C) KeepAliveOpt[C] { + return KeepAliveOpt[C]{ + maxRetries: maxRetries, + timeout: timeout, + onInactive: onInactive, + } +} + +// InactivityMonitorOpt notifies when a connection was inactive for a given duration. 
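+//
+// For example, to close a UDP client connection that has been silent for
+// 10 seconds (the duration is illustrative):
+//
+//	opt := WithInactivityMonitor(10*time.Second, func(cc *udpClient.Conn) {
+//		_ = cc.Close()
+//	})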
+type InactivityMonitorOpt[C OnInactiveFunc] struct { + duration time.Duration + onInactive C +} + +func (o InactivityMonitorOpt[C]) toTCPCreateInactivityMonitor(onInactive TCPOnInactive) func() tcpClient.InactivityMonitor { + return func() tcpClient.InactivityMonitor { + return inactivity.New(o.duration, onInactive) + } +} + +func (o InactivityMonitorOpt[C]) toUDPCreateInactivityMonitor(onInactive UDPOnInactive) func() udpClient.InactivityMonitor { + return func() udpClient.InactivityMonitor { + return inactivity.New(o.duration, onInactive) + } +} + +func (o InactivityMonitorOpt[C]) TCPServerApply(cfg *tcpServer.Config) { + switch onInactive := any(o.onInactive).(type) { + case TCPOnInactive: + cfg.CreateInactivityMonitor = o.toTCPCreateInactivityMonitor(onInactive) + default: + var t TCPOnInactive + panicForInvalidOnInactiveFunc(onInactive, t) + } +} + +func (o InactivityMonitorOpt[C]) TCPClientApply(cfg *tcpClient.Config) { + switch onInactive := any(o.onInactive).(type) { + case TCPOnInactive: + cfg.CreateInactivityMonitor = o.toTCPCreateInactivityMonitor(onInactive) + default: + var t TCPOnInactive + panicForInvalidOnInactiveFunc(onInactive, t) + } +} + +func (o InactivityMonitorOpt[C]) UDPServerApply(cfg *udpServer.Config) { + switch onInactive := any(o.onInactive).(type) { + case UDPOnInactive: + cfg.CreateInactivityMonitor = o.toUDPCreateInactivityMonitor(onInactive) + default: + var t UDPOnInactive + panicForInvalidOnInactiveFunc(onInactive, t) + } +} + +func (o InactivityMonitorOpt[C]) DTLSServerApply(cfg *dtlsServer.Config) { + switch onInactive := any(o.onInactive).(type) { + case UDPOnInactive: + cfg.CreateInactivityMonitor = o.toUDPCreateInactivityMonitor(onInactive) + default: + var t UDPOnInactive + panicForInvalidOnInactiveFunc(onInactive, t) + } +} + +func (o InactivityMonitorOpt[C]) UDPClientApply(cfg *udpClient.Config) { + switch onInactive := any(o.onInactive).(type) { + case UDPOnInactive: + cfg.CreateInactivityMonitor = o.toUDPCreateInactivityMonitor(onInactive) + default: + var t UDPOnInactive + panicForInvalidOnInactiveFunc(onInactive, t) + } +} + +// WithInactivityMonitor set deadline's for read operations over client connection. +func WithInactivityMonitor[C OnInactiveFunc](duration time.Duration, onInactive C) InactivityMonitorOpt[C] { + return InactivityMonitorOpt[C]{ + duration: duration, + onInactive: onInactive, + } +} + +// NetOpt network option. +type NetOpt struct { + net string +} + +func (o NetOpt) TCPClientApply(cfg *tcpClient.Config) { + cfg.Net = o.net +} + +func (o NetOpt) UDPClientApply(cfg *udpClient.Config) { + cfg.Net = o.net +} + +// WithNetwork define's tcp version (udp4, udp6, tcp) for client. +func WithNetwork(net string) NetOpt { + return NetOpt{net: net} +} + +// PeriodicRunnerOpt function which is executed in every ticks +type PeriodicRunnerOpt struct { + periodicRunner periodic.Func +} + +func (o PeriodicRunnerOpt) TCPClientApply(cfg *tcpClient.Config) { + cfg.PeriodicRunner = o.periodicRunner +} + +func (o PeriodicRunnerOpt) TCPServerApply(cfg *tcpServer.Config) { + cfg.PeriodicRunner = o.periodicRunner +} + +func (o PeriodicRunnerOpt) UDPClientApply(cfg *udpClient.Config) { + cfg.PeriodicRunner = o.periodicRunner +} + +func (o PeriodicRunnerOpt) UDPServerApply(cfg *udpServer.Config) { + cfg.PeriodicRunner = o.periodicRunner +} + +func (o PeriodicRunnerOpt) DTLSServerApply(cfg *dtlsServer.Config) { + cfg.PeriodicRunner = o.periodicRunner +} + +// WithPeriodicRunner set function which is executed in every ticks. 
+func WithPeriodicRunner(periodicRunner periodic.Func) PeriodicRunnerOpt { + return PeriodicRunnerOpt{periodicRunner: periodicRunner} +} + +// BlockwiseOpt network option. +type BlockwiseOpt struct { + transferTimeout time.Duration + enable bool + szx blockwise.SZX +} + +func (o BlockwiseOpt) UDPServerApply(cfg *udpServer.Config) { + cfg.BlockwiseEnable = o.enable + cfg.BlockwiseSZX = o.szx + cfg.BlockwiseTransferTimeout = o.transferTimeout +} + +func (o BlockwiseOpt) DTLSServerApply(cfg *dtlsServer.Config) { + cfg.BlockwiseEnable = o.enable + cfg.BlockwiseSZX = o.szx + cfg.BlockwiseTransferTimeout = o.transferTimeout +} + +func (o BlockwiseOpt) UDPClientApply(cfg *udpClient.Config) { + cfg.BlockwiseEnable = o.enable + cfg.BlockwiseSZX = o.szx + cfg.BlockwiseTransferTimeout = o.transferTimeout +} + +func (o BlockwiseOpt) TCPServerApply(cfg *tcpServer.Config) { + cfg.BlockwiseEnable = o.enable + cfg.BlockwiseSZX = o.szx + cfg.BlockwiseTransferTimeout = o.transferTimeout +} + +func (o BlockwiseOpt) TCPClientApply(cfg *tcpClient.Config) { + cfg.BlockwiseEnable = o.enable + cfg.BlockwiseSZX = o.szx + cfg.BlockwiseTransferTimeout = o.transferTimeout +} + +// WithBlockwise configure's blockwise transfer. +func WithBlockwise(enable bool, szx blockwise.SZX, transferTimeout time.Duration) BlockwiseOpt { + return BlockwiseOpt{ + enable: enable, + szx: szx, + transferTimeout: transferTimeout, + } +} + +type OnNewConnFunc interface { + tcpServer.OnNewConnFunc | udpServer.OnNewConnFunc +} + +// OnNewConnOpt network option. +type OnNewConnOpt[F OnNewConnFunc] struct { + f F +} + +func panicForInvalidOnNewConnFunc(t, exp any) { + panic(fmt.Errorf("invalid OnNewConnFunc type %T, expected %T", t, exp)) +} + +func (o OnNewConnOpt[F]) UDPServerApply(cfg *udpServer.Config) { + switch v := any(o.f).(type) { + case udpServer.OnNewConnFunc: + cfg.OnNewConn = v + default: + var exp udpServer.OnNewConnFunc + panicForInvalidOnNewConnFunc(v, exp) + } +} + +func (o OnNewConnOpt[F]) DTLSServerApply(cfg *dtlsServer.Config) { + switch v := any(o.f).(type) { + case udpServer.OnNewConnFunc: + cfg.OnNewConn = v + default: + var exp udpServer.OnNewConnFunc + panicForInvalidOnNewConnFunc(v, exp) + } +} + +func (o OnNewConnOpt[F]) TCPServerApply(cfg *tcpServer.Config) { + switch v := any(o.f).(type) { + case tcpServer.OnNewConnFunc: + cfg.OnNewConn = v + default: + var exp tcpServer.OnNewConnFunc + panicForInvalidOnNewConnFunc(v, exp) + } +} + +// WithOnNewConn server's notify about new client connection. +func WithOnNewConn[F OnNewConnFunc](onNewConn F) OnNewConnOpt[F] { + return OnNewConnOpt[F]{ + f: onNewConn, + } +} + +// CloseSocketOpt close socket option. +type CloseSocketOpt struct{} + +func (o CloseSocketOpt) TCPClientApply(cfg *tcpClient.Config) { + cfg.CloseSocket = true +} + +func (o CloseSocketOpt) UDPClientApply(cfg *udpClient.Config) { + cfg.CloseSocket = true +} + +// WithCloseSocket closes socket at the close connection. +func WithCloseSocket() CloseSocketOpt { + return CloseSocketOpt{} +} + +// DialerOpt dialer option. +type DialerOpt struct { + dialer *net.Dialer +} + +func (o DialerOpt) UDPClientApply(cfg *udpClient.Config) { + if o.dialer != nil { + cfg.Dialer = o.dialer + } +} + +func (o DialerOpt) TCPClientApply(cfg *tcpClient.Config) { + if o.dialer != nil { + cfg.Dialer = o.dialer + } +} + +// WithDialer set dialer for dial. +func WithDialer(dialer *net.Dialer) DialerOpt { + return DialerOpt{ + dialer: dialer, + } +} + +// ConnectionCacheOpt network option. 
+type MessagePoolOpt struct {
+	messagePool *pool.Pool
+}
+
+func (o MessagePoolOpt) TCPServerApply(cfg *tcpServer.Config) {
+	cfg.MessagePool = o.messagePool
+}
+
+func (o MessagePoolOpt) TCPClientApply(cfg *tcpClient.Config) {
+	cfg.MessagePool = o.messagePool
+}
+
+func (o MessagePoolOpt) UDPServerApply(cfg *udpServer.Config) {
+	cfg.MessagePool = o.messagePool
+}
+
+func (o MessagePoolOpt) DTLSServerApply(cfg *dtlsServer.Config) {
+	cfg.MessagePool = o.messagePool
+}
+
+func (o MessagePoolOpt) UDPClientApply(cfg *udpClient.Config) {
+	cfg.MessagePool = o.messagePool
+}
+
+// WithMessagePool configures the message pool used for acquiring/releasing CoAP messages.
+func WithMessagePool(messagePool *pool.Pool) MessagePoolOpt {
+	return MessagePoolOpt{
+		messagePool: messagePool,
+	}
+}
+
+// GetTokenOpt token option.
+type GetTokenOpt struct {
+	getToken client.GetTokenFunc
+}
+
+func (o GetTokenOpt) TCPServerApply(cfg *tcpServer.Config) {
+	cfg.GetToken = o.getToken
+}
+
+func (o GetTokenOpt) TCPClientApply(cfg *tcpClient.Config) {
+	cfg.GetToken = o.getToken
+}
+
+func (o GetTokenOpt) UDPServerApply(cfg *udpServer.Config) {
+	cfg.GetToken = o.getToken
+}
+
+func (o GetTokenOpt) DTLSServerApply(cfg *dtlsServer.Config) {
+	cfg.GetToken = o.getToken
+}
+
+func (o GetTokenOpt) UDPClientApply(cfg *udpClient.Config) {
+	cfg.GetToken = o.getToken
+}
+
+// WithGetToken sets the function used for generating tokens.
+func WithGetToken(getToken client.GetTokenFunc) GetTokenOpt {
+	return GetTokenOpt{getToken: getToken}
+}
+
+// LimitClientParallelRequestOpt limits the number of parallel requests from a client.
+type LimitClientParallelRequestOpt struct {
+	limitClientParallelRequests int64
+}
+
+func (o LimitClientParallelRequestOpt) TCPServerApply(cfg *tcpServer.Config) {
+	cfg.LimitClientParallelRequests = o.limitClientParallelRequests
+}
+
+func (o LimitClientParallelRequestOpt) TCPClientApply(cfg *tcpClient.Config) {
+	cfg.LimitClientParallelRequests = o.limitClientParallelRequests
+}
+
+func (o LimitClientParallelRequestOpt) UDPServerApply(cfg *udpServer.Config) {
+	cfg.LimitClientParallelRequests = o.limitClientParallelRequests
+}
+
+func (o LimitClientParallelRequestOpt) DTLSServerApply(cfg *dtlsServer.Config) {
+	cfg.LimitClientParallelRequests = o.limitClientParallelRequests
+}
+
+func (o LimitClientParallelRequestOpt) UDPClientApply(cfg *udpClient.Config) {
+	cfg.LimitClientParallelRequests = o.limitClientParallelRequests
+}
+
+// WithLimitClientParallelRequest limits the number of parallel requests from a client. (default: 1)
+func WithLimitClientParallelRequest(limitClientParallelRequests int64) LimitClientParallelRequestOpt {
+	return LimitClientParallelRequestOpt{limitClientParallelRequests: limitClientParallelRequests}
+}
+
+// LimitClientEndpointParallelRequestOpt limits the number of parallel requests to an endpoint by a client.
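WithMessagePool and the parallel-request limit are plain config setters and combine freely on a client. A hedged sketch; the pool sizes mirror the pool.New(1024, 2048) call used by config.NewCommon later in this diff, and the meaning of the two arguments (pool capacity, buffer size) is an assumption, not something this patch specifies:

```go
package main

import (
	"log"

	"github.com/plgd-dev/go-coap/v3/message/pool"
	"github.com/plgd-dev/go-coap/v3/options"
	"github.com/plgd-dev/go-coap/v3/tcp"
)

func main() {
	// Same constructor the defaults use (see config.NewCommon below).
	p := pool.New(1024, 2048)

	cc, err := tcp.Dial("localhost:5683",
		options.WithMessagePool(p),
		options.WithLimitClientParallelRequest(4), // default: 1
	)
	if err != nil {
		log.Fatal(err)
	}
	defer cc.Close()
}
```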
+type LimitClientEndpointParallelRequestOpt struct { + limitClientEndpointParallelRequests int64 +} + +func (o LimitClientEndpointParallelRequestOpt) TCPServerApply(cfg *tcpServer.Config) { + cfg.LimitClientEndpointParallelRequests = o.limitClientEndpointParallelRequests +} + +func (o LimitClientEndpointParallelRequestOpt) TCPClientApply(cfg *tcpClient.Config) { + cfg.LimitClientEndpointParallelRequests = o.limitClientEndpointParallelRequests +} + +func (o LimitClientEndpointParallelRequestOpt) UDPServerApply(cfg *udpServer.Config) { + cfg.LimitClientEndpointParallelRequests = o.limitClientEndpointParallelRequests +} + +func (o LimitClientEndpointParallelRequestOpt) DTLSServerApply(cfg *dtlsServer.Config) { + cfg.LimitClientEndpointParallelRequests = o.limitClientEndpointParallelRequests +} + +func (o LimitClientEndpointParallelRequestOpt) UDPClientApply(cfg *udpClient.Config) { + cfg.LimitClientEndpointParallelRequests = o.limitClientEndpointParallelRequests +} + +// WithLimitClientEndpointParallelRequest limits number of parallel requests to endpoint from client. (default: 1) +func WithLimitClientEndpointParallelRequest(limitClientEndpointParallelRequests int64) LimitClientEndpointParallelRequestOpt { + return LimitClientEndpointParallelRequestOpt{limitClientEndpointParallelRequests: limitClientEndpointParallelRequests} +} + +// ReceivedMessageQueueSizeOpt limit's message queue size for received messages. +type ReceivedMessageQueueSizeOpt struct { + receivedMessageQueueSize int +} + +func (o ReceivedMessageQueueSizeOpt) TCPServerApply(cfg *tcpServer.Config) { + cfg.ReceivedMessageQueueSize = o.receivedMessageQueueSize +} + +func (o ReceivedMessageQueueSizeOpt) TCPClientApply(cfg *tcpClient.Config) { + cfg.ReceivedMessageQueueSize = o.receivedMessageQueueSize +} + +func (o ReceivedMessageQueueSizeOpt) UDPServerApply(cfg *udpServer.Config) { + cfg.ReceivedMessageQueueSize = o.receivedMessageQueueSize +} + +func (o ReceivedMessageQueueSizeOpt) DTLSServerApply(cfg *dtlsServer.Config) { + cfg.ReceivedMessageQueueSize = o.receivedMessageQueueSize +} + +func (o ReceivedMessageQueueSizeOpt) UDPClientApply(cfg *udpClient.Config) { + cfg.ReceivedMessageQueueSize = o.receivedMessageQueueSize +} + +// WithReceivedMessageQueueSize limit's message queue size for received messages. 
(default: 16) +func WithReceivedMessageQueueSize(receivedMessageQueueSize int) ReceivedMessageQueueSizeOpt { + return ReceivedMessageQueueSizeOpt{receivedMessageQueueSize: receivedMessageQueueSize} +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/options/config/common.go b/vendor/github.com/plgd-dev/go-coap/v3/options/config/common.go new file mode 100644 index 0000000000..e6fd876a31 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/options/config/common.go @@ -0,0 +1,61 @@ +package config + +import ( + "context" + "fmt" + "time" + + "github.com/plgd-dev/go-coap/v3/message" + "github.com/plgd-dev/go-coap/v3/message/pool" + "github.com/plgd-dev/go-coap/v3/net/blockwise" + "github.com/plgd-dev/go-coap/v3/net/client" + "github.com/plgd-dev/go-coap/v3/net/responsewriter" + "github.com/plgd-dev/go-coap/v3/pkg/runner/periodic" +) + +type ( + ErrorFunc = func(error) + HandlerFunc[C responsewriter.Client] func(w *responsewriter.ResponseWriter[C], r *pool.Message) + ProcessReceivedMessageFunc[C responsewriter.Client] func(req *pool.Message, cc C, handler HandlerFunc[C]) +) + +type Common[C responsewriter.Client] struct { + LimitClientParallelRequests int64 + LimitClientEndpointParallelRequests int64 + Ctx context.Context + Errors ErrorFunc + PeriodicRunner periodic.Func + MessagePool *pool.Pool + GetToken client.GetTokenFunc + MaxMessageSize uint32 + BlockwiseTransferTimeout time.Duration + BlockwiseSZX blockwise.SZX + BlockwiseEnable bool + ProcessReceivedMessage ProcessReceivedMessageFunc[C] + ReceivedMessageQueueSize int +} + +func NewCommon[C responsewriter.Client]() Common[C] { + return Common[C]{ + Ctx: context.Background(), + MaxMessageSize: 64 * 1024, + Errors: func(err error) { + fmt.Println(err) + }, + BlockwiseSZX: blockwise.SZX1024, + BlockwiseEnable: true, + BlockwiseTransferTimeout: time.Second * 3, + PeriodicRunner: func(f func(now time.Time) bool) { + go func() { + for f(time.Now()) { + time.Sleep(4 * time.Second) + } + }() + }, + MessagePool: pool.New(1024, 2048), + GetToken: message.GetToken, + LimitClientParallelRequests: 1, + LimitClientEndpointParallelRequests: 1, + ReceivedMessageQueueSize: 16, + } +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/options/tcpOptions.go b/vendor/github.com/plgd-dev/go-coap/v3/options/tcpOptions.go new file mode 100644 index 0000000000..ec2f9fb69f --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/options/tcpOptions.go @@ -0,0 +1,76 @@ +package options + +import ( + "crypto/tls" + + tcpClient "github.com/plgd-dev/go-coap/v3/tcp/client" + tcpServer "github.com/plgd-dev/go-coap/v3/tcp/server" +) + +// DisablePeerTCPSignalMessageCSMsOpt coap-tcp csm option. +type DisablePeerTCPSignalMessageCSMsOpt struct{} + +func (o DisablePeerTCPSignalMessageCSMsOpt) TCPServerApply(cfg *tcpServer.Config) { + cfg.DisablePeerTCPSignalMessageCSMs = true +} + +func (o DisablePeerTCPSignalMessageCSMsOpt) TCPClientApply(cfg *tcpClient.Config) { + cfg.DisablePeerTCPSignalMessageCSMs = true +} + +// WithDisablePeerTCPSignalMessageCSMs ignor peer's CSM message. +func WithDisablePeerTCPSignalMessageCSMs() DisablePeerTCPSignalMessageCSMsOpt { + return DisablePeerTCPSignalMessageCSMsOpt{} +} + +// DisableTCPSignalMessageCSMOpt coap-tcp csm option. 
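Paired with the WithDisableTCPSignalMessageCSM option that follows, this lets a TCP client skip RFC 8323 CSM signalling entirely, for example when talking to peers that never send one. An illustrative sketch, not part of the patch:

```go
package main

import (
	"log"

	"github.com/plgd-dev/go-coap/v3/options"
	"github.com/plgd-dev/go-coap/v3/tcp"
)

func main() {
	// Neither send our own CSM nor act on the peer's.
	cc, err := tcp.Dial("localhost:5683",
		options.WithDisablePeerTCPSignalMessageCSMs(),
		options.WithDisableTCPSignalMessageCSM(),
	)
	if err != nil {
		log.Fatal(err)
	}
	defer cc.Close()
}
```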
+type DisableTCPSignalMessageCSMOpt struct{}
+
+func (o DisableTCPSignalMessageCSMOpt) TCPServerApply(cfg *tcpServer.Config) {
+	cfg.DisableTCPSignalMessageCSM = true
+}
+
+func (o DisableTCPSignalMessageCSMOpt) TCPClientApply(cfg *tcpClient.Config) {
+	cfg.DisableTCPSignalMessageCSM = true
+}
+
+// WithDisableTCPSignalMessageCSM disables sending a CSM when the client connection is created.
+func WithDisableTCPSignalMessageCSM() DisableTCPSignalMessageCSMOpt {
+	return DisableTCPSignalMessageCSMOpt{}
+}
+
+// TLSOpt tls configuration option.
+type TLSOpt struct {
+	tlsCfg *tls.Config
+}
+
+func (o TLSOpt) TCPClientApply(cfg *tcpClient.Config) {
+	cfg.TLSCfg = o.tlsCfg
+}
+
+// WithTLS sets the TLS configuration; when set, Dial creates a TLS connection.
+func WithTLS(cfg *tls.Config) TLSOpt {
+	return TLSOpt{
+		tlsCfg: cfg,
+	}
+}
+
+// ConnectionCacheSizeOpt network option.
+type ConnectionCacheSizeOpt struct {
+	connectionCacheSize uint16
+}
+
+func (o ConnectionCacheSizeOpt) TCPServerApply(cfg *tcpServer.Config) {
+	cfg.ConnectionCacheSize = o.connectionCacheSize
+}
+
+func (o ConnectionCacheSizeOpt) TCPClientApply(cfg *tcpClient.Config) {
+	cfg.ConnectionCacheSize = o.connectionCacheSize
+}
+
+// WithConnectionCacheSize configures the maximum size of the cached read buffer.
+func WithConnectionCacheSize(connectionCacheSize uint16) ConnectionCacheSizeOpt {
+	return ConnectionCacheSizeOpt{
+		connectionCacheSize: connectionCacheSize,
+	}
+}
diff --git a/vendor/github.com/plgd-dev/go-coap/v3/options/udpOptions.go b/vendor/github.com/plgd-dev/go-coap/v3/options/udpOptions.go
new file mode 100644
index 0000000000..77f4ff931c
--- /dev/null
+++ b/vendor/github.com/plgd-dev/go-coap/v3/options/udpOptions.go
@@ -0,0 +1,70 @@
+package options
+
+import (
+	"time"
+
+	dtlsServer "github.com/plgd-dev/go-coap/v3/dtls/server"
+	udpClient "github.com/plgd-dev/go-coap/v3/udp/client"
+	udpServer "github.com/plgd-dev/go-coap/v3/udp/server"
+)
+
+// TransmissionOpt transmission options.
+type TransmissionOpt struct {
+	transmissionNStart             uint32
+	transmissionAcknowledgeTimeout time.Duration
+	transmissionMaxRetransmit      uint32
+}
+
+func (o TransmissionOpt) UDPServerApply(cfg *udpServer.Config) {
+	cfg.TransmissionNStart = o.transmissionNStart
+	cfg.TransmissionAcknowledgeTimeout = o.transmissionAcknowledgeTimeout
+	cfg.TransmissionMaxRetransmit = o.transmissionMaxRetransmit
+}
+
+func (o TransmissionOpt) DTLSServerApply(cfg *dtlsServer.Config) {
+	cfg.TransmissionNStart = o.transmissionNStart
+	cfg.TransmissionAcknowledgeTimeout = o.transmissionAcknowledgeTimeout
+	cfg.TransmissionMaxRetransmit = o.transmissionMaxRetransmit
+}
+
+func (o TransmissionOpt) UDPClientApply(cfg *udpClient.Config) {
+	cfg.TransmissionNStart = o.transmissionNStart
+	cfg.TransmissionAcknowledgeTimeout = o.transmissionAcknowledgeTimeout
+	cfg.TransmissionMaxRetransmit = o.transmissionMaxRetransmit
+}
+
+// WithTransmission sets options for (re)transmission of Confirmable messages.
+func WithTransmission(transmissionNStart uint32,
+	transmissionAcknowledgeTimeout time.Duration,
+	transmissionMaxRetransmit uint32,
+) TransmissionOpt {
+	return TransmissionOpt{
+		transmissionNStart:             transmissionNStart,
+		transmissionAcknowledgeTimeout: transmissionAcknowledgeTimeout,
+		transmissionMaxRetransmit:      transmissionMaxRetransmit,
+	}
+}
+
+// MTUOpt MTU option.
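WithTransmission maps directly onto the RFC 7252 transmission parameters (NSTART, ACK_TIMEOUT, MAX_RETRANSMIT). A sketch for a UDP client, assuming udp.Dial, the UDP counterpart of the tcp.Dial added later in this diff:

```go
package main

import (
	"log"
	"time"

	"github.com/plgd-dev/go-coap/v3/options"
	"github.com/plgd-dev/go-coap/v3/udp"
)

func main() {
	// NSTART=1 outstanding interaction, 2s initial ACK timeout, 4 retransmits.
	cc, err := udp.Dial("localhost:5683",
		options.WithTransmission(1, 2*time.Second, 4),
	)
	if err != nil {
		log.Fatal(err)
	}
	defer cc.Close()
}
```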
+type MTUOpt struct { + mtu uint16 +} + +func (o MTUOpt) UDPServerApply(cfg *udpServer.Config) { + cfg.MTU = o.mtu +} + +func (o MTUOpt) DTLSServerApply(cfg *dtlsServer.Config) { + cfg.MTU = o.mtu +} + +func (o MTUOpt) UDPClientApply(cfg *udpClient.Config) { + cfg.MTU = o.mtu +} + +// Setup MTU unit +func WithMTU(mtu uint16) MTUOpt { + return MTUOpt{ + mtu: mtu, + } +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/pkg/cache/cache.go b/vendor/github.com/plgd-dev/go-coap/v3/pkg/cache/cache.go new file mode 100644 index 0000000000..21099fc313 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/pkg/cache/cache.go @@ -0,0 +1,85 @@ +package cache + +import ( + "time" + + "github.com/plgd-dev/go-coap/v3/pkg/sync" + "go.uber.org/atomic" +) + +func DefaultOnExpire[D any](D) { + // for nothing on expire +} + +type Element[D any] struct { + ValidUntil atomic.Time + data D + onExpire func(d D) +} + +func (e *Element[D]) IsExpired(now time.Time) bool { + value := e.ValidUntil.Load() + if value.IsZero() { + return false + } + return now.After(value) +} + +func (e *Element[D]) Data() D { + return e.data +} + +func NewElement[D any](data D, validUntil time.Time, onExpire func(d D)) *Element[D] { + if onExpire == nil { + onExpire = DefaultOnExpire[D] + } + e := &Element[D]{data: data, onExpire: onExpire} + e.ValidUntil.Store(validUntil) + return e +} + +type Cache[K comparable, D any] struct { + *sync.Map[K, *Element[D]] +} + +func NewCache[K comparable, D any]() *Cache[K, D] { + return &Cache[K, D]{ + Map: sync.NewMap[K, *Element[D]](), + } +} + +func (c *Cache[K, D]) LoadOrStore(key K, e *Element[D]) (actual *Element[D], loaded bool) { + now := time.Now() + c.Map.ReplaceWithFunc(key, func(oldValue *Element[D], oldLoaded bool) (newValue *Element[D], deleteValue bool) { + if oldLoaded { + if !oldValue.IsExpired(now) { + actual = oldValue + return oldValue, false + } + } + actual = e + return e, false + }) + return actual, actual != e +} + +func (c *Cache[K, D]) Load(key K) (actual *Element[D]) { + actual, loaded := c.Map.Load(key) + if !loaded { + return nil + } + if actual.IsExpired(time.Now()) { + return nil + } + return actual +} + +func (c *Cache[K, D]) CheckExpirations(now time.Time) { + c.Range(func(key K, value *Element[D]) bool { + if value.IsExpired(now) { + c.Map.Delete(key) + value.onExpire(value.Data()) + } + return true + }) +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/pkg/connections/connections.go b/vendor/github.com/plgd-dev/go-coap/v3/pkg/connections/connections.go new file mode 100644 index 0000000000..d2de86579c --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/pkg/connections/connections.go @@ -0,0 +1,73 @@ +package connections + +import ( + "context" + "fmt" + "net" + "sync" + "time" +) + +type Connections struct { + data *sync.Map +} + +func New() *Connections { + return &Connections{ + data: &sync.Map{}, + } +} + +type Connection interface { + Context() context.Context + CheckExpirations(now time.Time) + Close() error + RemoteAddr() net.Addr +} + +func (c *Connections) Store(conn Connection) { + c.data.Store(conn.RemoteAddr().String(), conn) +} + +func (c *Connections) length() int { + var l int + c.data.Range(func(k, v interface{}) bool { + l++ + return true + }) + return l +} + +func (c *Connections) copyConnections() []Connection { + m := make([]Connection, 0, c.length()) + c.data.Range(func(key, value interface{}) bool { + con, ok := value.(Connection) + if !ok { + panic(fmt.Errorf("invalid type %T in connections map", con)) + } + m = append(m, con) + 
return true + }) + return m +} + +func (c *Connections) CheckExpirations(now time.Time) { + for _, cc := range c.copyConnections() { + select { + case <-cc.Context().Done(): + continue + default: + cc.CheckExpirations(now) + } + } +} + +func (c *Connections) Close() { + for _, cc := range c.copyConnections() { + _ = cc.Close() + } +} + +func (c *Connections) Delete(conn Connection) { + c.data.Delete(conn.RemoteAddr().String()) +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/pkg/errors/error.go b/vendor/github.com/plgd-dev/go-coap/v3/pkg/errors/error.go new file mode 100644 index 0000000000..e98098b58c --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/pkg/errors/error.go @@ -0,0 +1,5 @@ +package errors + +import "errors" + +var ErrKeyAlreadyExists = errors.New("key already exists") diff --git a/vendor/github.com/plgd-dev/go-coap/v3/pkg/fn/funcList.go b/vendor/github.com/plgd-dev/go-coap/v3/pkg/fn/funcList.go new file mode 100644 index 0000000000..9ed31e801b --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/pkg/fn/funcList.go @@ -0,0 +1,19 @@ +package fn + +type FuncList []func() + +// Return a function that executions all added functions +// +// Functions are executed in reverse order they were added. +func (c FuncList) ToFunction() func() { + return func() { + for i := range c { + c[len(c)-1-i]() + } + } +} + +// Execute all added functions +func (c FuncList) Execute() { + c.ToFunction()() +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/pkg/rand/rand.go b/vendor/github.com/plgd-dev/go-coap/v3/pkg/rand/rand.go new file mode 100644 index 0000000000..124d3132b5 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/pkg/rand/rand.go @@ -0,0 +1,31 @@ +package rand + +import ( + "math/rand" + "sync" +) + +type Rand struct { + src *rand.Rand + lock sync.Mutex +} + +func NewRand(seed int64) *Rand { + return &Rand{ + src: rand.New(rand.NewSource(seed)), + } +} + +func (l *Rand) Int63() int64 { + l.lock.Lock() + val := l.src.Int63() + l.lock.Unlock() + return val +} + +func (l *Rand) Uint32() uint32 { + l.lock.Lock() + val := l.src.Uint32() + l.lock.Unlock() + return val +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/pkg/runner/periodic/periodic.go b/vendor/github.com/plgd-dev/go-coap/v3/pkg/runner/periodic/periodic.go new file mode 100644 index 0000000000..7cfe624577 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/pkg/runner/periodic/periodic.go @@ -0,0 +1,43 @@ +package periodic + +import ( + "sync" + "sync/atomic" + "time" +) + +type Func = func(f func(now time.Time) bool) + +func New(stop <-chan struct{}, tick time.Duration) Func { + var m sync.Map + var idx uint64 + go func() { + t := time.NewTicker(tick) + defer t.Stop() + for { + var now time.Time + select { + case now = <-t.C: + case <-stop: + return + } + v := make(map[uint64]func(time.Time) bool) + m.Range(func(key, value interface{}) bool { + v[key.(uint64)] = value.(func(time.Time) bool) //nolint:forcetypeassert + return true + }) + for k, f := range v { + if ok := f(now); !ok { + m.Delete(k) + } + } + } + }() + return func(f func(time.Time) bool) { + if f == nil { + return + } + v := atomic.AddUint64(&idx, 1) + m.Store(v, f) + } +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/pkg/sync/map.go b/vendor/github.com/plgd-dev/go-coap/v3/pkg/sync/map.go new file mode 100644 index 0000000000..321aefb373 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/pkg/sync/map.go @@ -0,0 +1,215 @@ +package sync + +import ( + "sync" + + "golang.org/x/exp/maps" +) + +// Map is like a Go 
map[interface{}]interface{} but is safe for concurrent use by multiple goroutines. +type Map[K comparable, V any] struct { + mutex sync.RWMutex + data map[K]V +} + +// NewMap creates map. +func NewMap[K comparable, V any]() *Map[K, V] { + return &Map[K, V]{ + data: make(map[K]V), + } +} + +// Store sets the value for a key. +func (m *Map[K, V]) Store(key K, value V) { + m.mutex.Lock() + defer m.mutex.Unlock() + m.data[key] = value +} + +// Load returns the value stored in the map for a key, or nil if no value is present. The loaded value is read-only and should not be modified. +// The ok result indicates whether value was found in the map. +func (m *Map[K, V]) Load(key K) (V, bool) { + m.mutex.RLock() + defer m.mutex.RUnlock() + v, ok := m.data[key] + return v, ok +} + +// LoadOrStore returns the existing value for the key if present. The loaded value is read-only and should not be modified. +// Otherwise, it stores and returns the given value. The loaded result is true if the value was loaded, false if stored. +func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) { + m.mutex.RLock() + v, ok := m.data[key] + m.mutex.RUnlock() + if ok { + return v, true + } + m.mutex.Lock() + m.data[key] = value + m.mutex.Unlock() + return value, false +} + +// Replace replaces the existing value with a new value and returns old value for the key. +func (m *Map[K, V]) Replace(key K, value V) (oldValue V, oldLoaded bool) { + m.mutex.Lock() + defer m.mutex.Unlock() + v, ok := m.data[key] + m.data[key] = value + return v, ok +} + +// Delete deletes the value for the key. +func (m *Map[K, V]) Delete(key K) { + m.mutex.Lock() + defer m.mutex.Unlock() + delete(m.data, key) +} + +// LoadAndDelete loads and deletes the value for the key. +func (m *Map[K, V]) LoadAndDelete(key K) (V, bool) { + m.mutex.Lock() + defer m.mutex.Unlock() + value, ok := m.data[key] + delete(m.data, key) + return value, ok +} + +// LoadAndDelete loads and deletes the value for the key. +func (m *Map[K, V]) LoadAndDeleteAll() map[K]V { + m.mutex.Lock() + data := m.data + m.data = make(map[K]V) + m.mutex.Unlock() + return data +} + +// CopyData creates a deep copy of the internal map. +func (m *Map[K, V]) CopyData() map[K]V { + c := make(map[K]V) + m.mutex.RLock() + maps.Copy(c, m.data) + m.mutex.RUnlock() + return c +} + +// Length returns number of stored values. +func (m *Map[K, V]) Length() int { + m.mutex.RLock() + defer m.mutex.RUnlock() + return len(m.data) +} + +// Range calls f sequentially for each key and value present in the map. If f returns false, range stops the iteration. +// +// Range does not copy the whole map, instead the read lock is locked on iteration of the map, and unlocked before f is called. +func (m *Map[K, V]) Range(f func(key K, value V) bool) { + m.mutex.RLock() + defer m.mutex.RUnlock() + for key, value := range m.data { + m.mutex.RUnlock() + ok := f(key, value) + m.mutex.RLock() + if !ok { + return + } + } +} + +// Range2 calls f sequentially for each key and value present in the map. If f returns false, range stops the iteration. +// +// Range2 differs from Range by keepting a read lock locked during the whole call. +func (m *Map[K, V]) Range2(f func(key K, value V) bool) { + m.mutex.RLock() + defer m.mutex.RUnlock() + for key, value := range m.data { + ok := f(key, value) + if !ok { + return + } + } +} + +// StoreWithFunc creates a new element and stores it in the map under the given key. +// +// The createFunc is invoked under a write lock. 
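The *WithFunc helpers below run their callbacks under the map's own lock, which makes read-modify-write updates atomic without exposing the mutex. A small self-contained sketch using only methods from this file:

```go
package main

import (
	"fmt"

	coapSync "github.com/plgd-dev/go-coap/v3/pkg/sync"
)

func main() {
	m := coapSync.NewMap[string, int]()
	m.Store("requests", 1)

	if v, loaded := m.LoadOrStore("requests", 99); loaded {
		fmt.Println("kept existing value:", v) // 1
	}

	// Atomic read-modify-write under the map's write lock.
	m.ReplaceWithFunc("requests", func(old int, loaded bool) (int, bool) {
		return old + 1, false // store old+1, doDelete=false
	})

	v, _ := m.Load("requests")
	fmt.Println("after increment:", v) // 2
}
```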
+func (m *Map[K, V]) StoreWithFunc(key K, createFunc func() V) { + m.mutex.Lock() + defer m.mutex.Unlock() + m.data[key] = createFunc() +} + +// LoadWithFunc tries to load element for key from the map, if it exists then the onload functions is invoked on it. +// +// The onLoadFunc is invoked under a read lock. +func (m *Map[K, V]) LoadWithFunc(key K, onLoadFunc func(value V) V) (V, bool) { + m.mutex.RLock() + defer m.mutex.RUnlock() + value, ok := m.data[key] + if ok && onLoadFunc != nil { + value = onLoadFunc(value) + } + return value, ok +} + +// LoadOrStoreWithFunc loads an existing element from the map or creates a new element and stores it in the map +// +// The onLoadFunc or createFunc are invoked under a write lock. +func (m *Map[K, V]) LoadOrStoreWithFunc(key K, onLoadFunc func(value V) V, createFunc func() V) (actual V, loaded bool) { + m.mutex.Lock() + defer m.mutex.Unlock() + v, ok := m.data[key] + if ok { + if onLoadFunc != nil { + v = onLoadFunc(v) + } + return v, true + } + v = createFunc() + m.data[key] = v + return v, false +} + +// ReplaceWithFunc checks whether key exists in the map, invokes the onReplaceFunc callback on the pair (value, ok) and either deletes or stores the element +// in the map based on the returned values from the onReplaceFunc callback. +// +// The onReplaceFunc callback is invoked under a write lock. +func (m *Map[K, V]) ReplaceWithFunc(key K, onReplaceFunc func(oldValue V, oldLoaded bool) (newValue V, doDelete bool)) (oldValue V, oldLoaded bool) { + m.mutex.Lock() + defer m.mutex.Unlock() + v, ok := m.data[key] + newValue, del := onReplaceFunc(v, ok) + if del { + delete(m.data, key) + return v, ok + } + m.data[key] = newValue + return v, ok +} + +// DeleteWithFunc removes the key from the map and if a value existed invokes the onDeleteFunc callback on the removed value. +// +// The onDeleteFunc callback is invoked under a write lock. +func (m *Map[K, V]) DeleteWithFunc(key K, onDeleteFunc func(value V)) { + _, _ = m.LoadAndDeleteWithFunc(key, func(value V) V { + onDeleteFunc(value) + return value + }) +} + +// LoadAndDeleteWithFunc removes the key from the map and if a value existed invokes the onLoadFunc callback on the removed and return it. +// +// The onLoadFunc callback is invoked under a write lock. +func (m *Map[K, V]) LoadAndDeleteWithFunc(key K, onLoadFunc func(value V) V) (V, bool) { + var v V + var loaded bool + m.ReplaceWithFunc(key, func(oldValue V, oldLoaded bool) (newValue V, doDelete bool) { + if oldLoaded { + loaded = true + v = onLoadFunc(oldValue) + return v, true + } + return oldValue, true + }) + return v, loaded +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/renovate.json b/vendor/github.com/plgd-dev/go-coap/v3/renovate.json new file mode 100644 index 0000000000..f844c2ed8b --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/renovate.json @@ -0,0 +1,15 @@ +{ + "extends": [ + "config:base" + ], + "postUpdateOptions": [ + "gomodTidy" + ], + "commitBody": "Generated by renovateBot", + "packageRules": [ + { + "packagePatterns": [".+"], + "schedule": ["on the first day of the month"] + } + ] +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/server.go b/vendor/github.com/plgd-dev/go-coap/v3/server.go new file mode 100644 index 0000000000..1df5ab6904 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/server.go @@ -0,0 +1,161 @@ +// Package coap provides a CoAP client and server. 
+package coap + +import ( + "crypto/tls" + "fmt" + + piondtls "github.com/pion/dtls/v2" + "github.com/plgd-dev/go-coap/v3/dtls" + dtlsServer "github.com/plgd-dev/go-coap/v3/dtls/server" + "github.com/plgd-dev/go-coap/v3/mux" + "github.com/plgd-dev/go-coap/v3/net" + "github.com/plgd-dev/go-coap/v3/options" + "github.com/plgd-dev/go-coap/v3/tcp" + tcpServer "github.com/plgd-dev/go-coap/v3/tcp/server" + "github.com/plgd-dev/go-coap/v3/udp" + udpServer "github.com/plgd-dev/go-coap/v3/udp/server" +) + +// ListenAndServe Starts a server on address and network specified Invoke handler +// for incoming queries. +func ListenAndServe(network string, addr string, handler mux.Handler) (err error) { + switch network { + case "udp", "udp4", "udp6", "": + l, err := net.NewListenUDP(network, addr) + if err != nil { + return err + } + defer func() { + if errC := l.Close(); errC != nil && err == nil { + err = errC + } + }() + s := udp.NewServer(options.WithMux(handler)) + return s.Serve(l) + case "tcp", "tcp4", "tcp6": + l, err := net.NewTCPListener(network, addr) + if err != nil { + return err + } + defer func() { + if errC := l.Close(); errC != nil && err == nil { + err = errC + } + }() + s := tcp.NewServer(options.WithMux(handler)) + return s.Serve(l) + default: + return fmt.Errorf("invalid network (%v)", network) + } +} + +// ListenAndServeTCPTLS Starts a server on address and network over TLS specified Invoke handler +// for incoming queries. +func ListenAndServeTCPTLS(network, addr string, config *tls.Config, handler mux.Handler) (err error) { + l, err := net.NewTLSListener(network, addr, config) + if err != nil { + return err + } + defer func() { + if errC := l.Close(); errC != nil && err == nil { + err = errC + } + }() + s := tcp.NewServer(options.WithMux(handler)) + return s.Serve(l) +} + +// ListenAndServeDTLS Starts a server on address and network over DTLS specified Invoke handler +// for incoming queries. +func ListenAndServeDTLS(network string, addr string, config *piondtls.Config, handler mux.Handler) (err error) { + l, err := net.NewDTLSListener(network, addr, config) + if err != nil { + return err + } + defer func() { + if errC := l.Close(); errC != nil && err == nil { + err = errC + } + }() + s := dtls.NewServer(options.WithMux(handler)) + return s.Serve(l) +} + +// ListenAndServeWithOption Starts a server on address and network specified Invoke options +// for incoming queries. The options is only support tcpServer.Option and udpServer.Option +func ListenAndServeWithOptions(network, addr string, opts ...any) (err error) { + tcpOptions := []tcpServer.Option{} + udpOptions := []udpServer.Option{} + for _, opt := range opts { + switch o := opt.(type) { + case tcpServer.Option: + tcpOptions = append(tcpOptions, o) + case udpServer.Option: + udpOptions = append(udpOptions, o) + default: + return fmt.Errorf("only support tcpServer.Option and udpServer.Option") + } + } + + switch network { + case "udp", "udp4", "udp6", "": + l, err := net.NewListenUDP(network, addr) + if err != nil { + return err + } + defer func() { + if errC := l.Close(); errC != nil && err == nil { + err = errC + } + }() + s := udp.NewServer(udpOptions...) + return s.Serve(l) + case "tcp", "tcp4", "tcp6": + l, err := net.NewTCPListener(network, addr) + if err != nil { + return err + } + defer func() { + if errC := l.Close(); errC != nil && err == nil { + err = errC + } + }() + s := tcp.NewServer(tcpOptions...) 
+ return s.Serve(l) + default: + return fmt.Errorf("invalid network (%v)", network) + } +} + +// ListenAndServeTCPTLSWithOptions Starts a server on address and network over TLS specified Invoke options +// for incoming queries. +func ListenAndServeTCPTLSWithOptions(network, addr string, config *tls.Config, opts ...tcpServer.Option) (err error) { + l, err := net.NewTLSListener(network, addr, config) + if err != nil { + return err + } + defer func() { + if errC := l.Close(); errC != nil && err == nil { + err = errC + } + }() + s := tcp.NewServer(opts...) + return s.Serve(l) +} + +// ListenAndServeDTLSWithOptions Starts a server on address and network over DTLS specified Invoke options +// for incoming queries. +func ListenAndServeDTLSWithOptions(network string, addr string, config *piondtls.Config, opts ...dtlsServer.Option) (err error) { + l, err := net.NewDTLSListener(network, addr, config) + if err != nil { + return err + } + defer func() { + if errC := l.Close(); errC != nil && err == nil { + err = errC + } + }() + s := dtls.NewServer(opts...) + return s.Serve(l) +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/sonar-project.properties b/vendor/github.com/plgd-dev/go-coap/v3/sonar-project.properties new file mode 100644 index 0000000000..b1b8facfc2 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/sonar-project.properties @@ -0,0 +1,24 @@ +sonar.projectKey=plgd-dev_go-coap +sonar.organization=plgd-dev + +# This is the name and version displayed in the SonarCloud UI. +#sonar.projectName=hub +#sonar.projectVersion=1.0 + +#sonar.log.level=DEBUG +#sonar.verbose=true + +sonar.python.version=3.8 + +sonar.sources=. +sonar.exclusions=**/*_test.go,**/*.pb.go,**/*.pb.gw.go,**/options.go,**/main.go,v3/** + +sonar.tests=. +sonar.test.inclusions=**/*_test.go +sonar.test.exclusions= + +#wildcard do not work for tests.reportPaths +#sonar.go.tests.reportPaths=.tmp/report/certificate-authority.report.json,.tmp/report/cloud2cloud-connector.report.json,.tmp/report/cloud2cloud-gateway.report.json,.tmp/report/coap-gateway.report.json,.tmp/report/grpc-gateway.report.json,.tmp/report/http-gateway.report.json,.tmp/report/identity-store.report.json,.tmp/report/resource-aggregate.report.json,.tmp/report/resource-directory.report.json + +sonar.go.coverage.reportPaths=./coverage.txt +sonar.coverage.exclusions=examples/**,**/main.go,**/*.pb.go,**/*.pb.gw.go,**/*.js,**/*.py,**/*_test.go diff --git a/vendor/github.com/plgd-dev/go-coap/v3/tcp/client.go b/vendor/github.com/plgd-dev/go-coap/v3/tcp/client.go new file mode 100644 index 0000000000..d407bd7885 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/tcp/client.go @@ -0,0 +1,110 @@ +package tcp + +import ( + "crypto/tls" + "fmt" + "net" + "time" + + "github.com/plgd-dev/go-coap/v3/message" + "github.com/plgd-dev/go-coap/v3/message/pool" + coapNet "github.com/plgd-dev/go-coap/v3/net" + "github.com/plgd-dev/go-coap/v3/net/blockwise" + "github.com/plgd-dev/go-coap/v3/net/monitor/inactivity" + "github.com/plgd-dev/go-coap/v3/options" + client "github.com/plgd-dev/go-coap/v3/tcp/client" +) + +// A Option sets options such as credentials, keepalive parameters, etc. +type Option interface { + TCPClientApply(cfg *client.Config) +} + +// Dial creates a client connection to the given target. 
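The ListenAndServe variants above differ only in listener construction; the handler is always a mux.Handler. A sketch of the smallest possible server; it assumes the v3 mux router API (mux.NewRouter, Handle, and ResponseWriter.SetResponse), which is not part of this diff:

```go
package main

import (
	"bytes"
	"log"

	coap "github.com/plgd-dev/go-coap/v3"
	"github.com/plgd-dev/go-coap/v3/message"
	"github.com/plgd-dev/go-coap/v3/message/codes"
	"github.com/plgd-dev/go-coap/v3/mux"
)

func main() {
	r := mux.NewRouter()
	if err := r.Handle("/ping", mux.HandlerFunc(func(w mux.ResponseWriter, m *mux.Message) {
		if err := w.SetResponse(codes.Content, message.TextPlain, bytes.NewReader([]byte("pong"))); err != nil {
			log.Printf("cannot set response: %v", err)
		}
	})); err != nil {
		log.Fatal(err)
	}
	// The network string picks the transport; handler wiring is identical.
	log.Fatal(coap.ListenAndServe("udp", ":5683", r))
}
```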
+func Dial(target string, opts ...Option) (*client.Conn, error) { + cfg := client.DefaultConfig + for _, o := range opts { + o.TCPClientApply(&cfg) + } + + var conn net.Conn + var err error + if cfg.TLSCfg != nil { + conn, err = tls.DialWithDialer(cfg.Dialer, cfg.Net, target, cfg.TLSCfg) + } else { + conn, err = cfg.Dialer.DialContext(cfg.Ctx, cfg.Net, target) + } + if err != nil { + return nil, err + } + opts = append(opts, options.WithCloseSocket()) + return Client(conn, opts...), nil +} + +// Client creates client over tcp/tcp-tls connection. +func Client(conn net.Conn, opts ...Option) *client.Conn { + cfg := client.DefaultConfig + for _, o := range opts { + o.TCPClientApply(&cfg) + } + if cfg.Errors == nil { + cfg.Errors = func(error) { + // default no-op + } + } + if cfg.CreateInactivityMonitor == nil { + cfg.CreateInactivityMonitor = func() client.InactivityMonitor { + return inactivity.NewNilMonitor[*client.Conn]() + } + } + if cfg.MessagePool == nil { + cfg.MessagePool = pool.New(0, 0) + } + errorsFunc := cfg.Errors + cfg.Errors = func(err error) { + if coapNet.IsCancelOrCloseError(err) { + // this error was produced by cancellation context or closing connection. + return + } + errorsFunc(fmt.Errorf("tcp: %w", err)) + } + + createBlockWise := func(cc *client.Conn) *blockwise.BlockWise[*client.Conn] { + return nil + } + if cfg.BlockwiseEnable { + createBlockWise = func(cc *client.Conn) *blockwise.BlockWise[*client.Conn] { + v := cc + return blockwise.New( + v, + cfg.BlockwiseTransferTimeout, + cfg.Errors, + func(token message.Token) (*pool.Message, bool) { + return v.GetObservationRequest(token) + }, + ) + } + } + + l := coapNet.NewConn(conn) + monitor := cfg.CreateInactivityMonitor() + cc := client.NewConn(l, + createBlockWise, + monitor, + &cfg, + ) + + cfg.PeriodicRunner(func(now time.Time) bool { + cc.CheckExpirations(now) + return cc.Context().Err() == nil + }) + + go func() { + err := cc.Run() + if err != nil { + cfg.Errors(fmt.Errorf("%v: %w", cc.RemoteAddr(), err)) + } + }() + + return cc +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/tcp/client/config.go b/vendor/github.com/plgd-dev/go-coap/v3/tcp/client/config.go new file mode 100644 index 0000000000..dd212d989c --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/tcp/client/config.go @@ -0,0 +1,49 @@ +package client + +import ( + "crypto/tls" + "fmt" + "net" + "time" + + "github.com/plgd-dev/go-coap/v3/message" + "github.com/plgd-dev/go-coap/v3/message/codes" + "github.com/plgd-dev/go-coap/v3/message/pool" + "github.com/plgd-dev/go-coap/v3/net/monitor/inactivity" + "github.com/plgd-dev/go-coap/v3/net/responsewriter" + "github.com/plgd-dev/go-coap/v3/options/config" +) + +var DefaultConfig = func() Config { + opts := Config{ + Common: config.NewCommon[*Conn](), + CreateInactivityMonitor: func() InactivityMonitor { + return inactivity.NewNilMonitor[*Conn]() + }, + Dialer: &net.Dialer{Timeout: time.Second * 3}, + Net: "tcp", + ConnectionCacheSize: 2048, + } + opts.Handler = func(w *responsewriter.ResponseWriter[*Conn], r *pool.Message) { + switch r.Code() { + case codes.POST, codes.PUT, codes.GET, codes.DELETE: + if err := w.SetResponse(codes.NotFound, message.TextPlain, nil); err != nil { + opts.Errors(fmt.Errorf("client handler: cannot set response: %w", err)) + } + } + } + return opts +}() + +type Config struct { + config.Common[*Conn] + CreateInactivityMonitor CreateInactivityMonitorFunc + Net string + Dialer *net.Dialer + TLSCfg *tls.Config + Handler HandlerFunc + ConnectionCacheSize uint16 + 
DisablePeerTCPSignalMessageCSMs bool + CloseSocket bool + DisableTCPSignalMessageCSM bool +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/tcp/client/conn.go b/vendor/github.com/plgd-dev/go-coap/v3/tcp/client/conn.go new file mode 100644 index 0000000000..2f3da22388 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/tcp/client/conn.go @@ -0,0 +1,370 @@ +package client + +import ( + "context" + "errors" + "fmt" + "net" + "time" + + "github.com/plgd-dev/go-coap/v3/message" + "github.com/plgd-dev/go-coap/v3/message/codes" + "github.com/plgd-dev/go-coap/v3/message/pool" + coapNet "github.com/plgd-dev/go-coap/v3/net" + "github.com/plgd-dev/go-coap/v3/net/blockwise" + "github.com/plgd-dev/go-coap/v3/net/client" + limitparallelrequests "github.com/plgd-dev/go-coap/v3/net/client/limitParallelRequests" + "github.com/plgd-dev/go-coap/v3/net/observation" + "github.com/plgd-dev/go-coap/v3/net/responsewriter" + coapErrors "github.com/plgd-dev/go-coap/v3/pkg/errors" + coapSync "github.com/plgd-dev/go-coap/v3/pkg/sync" + "go.uber.org/atomic" +) + +type InactivityMonitor interface { + Notify() + CheckInactivity(now time.Time, cc *Conn) +} + +type ( + HandlerFunc = func(*responsewriter.ResponseWriter[*Conn], *pool.Message) + ErrorFunc = func(error) + EventFunc = func() + GetMIDFunc = func() int32 + CreateInactivityMonitorFunc = func() InactivityMonitor +) + +type Notifier interface { + Notify() +} + +// Conn represents a virtual connection to a conceptual endpoint, to perform COAPs commands. +type Conn struct { + *client.Client[*Conn] + session *Session + observationHandler *observation.Handler[*Conn] + processReceivedMessage func(req *pool.Message, cc *Conn, handler HandlerFunc) + tokenHandlerContainer *coapSync.Map[uint64, HandlerFunc] + blockWise *blockwise.BlockWise[*Conn] + blockwiseSZX blockwise.SZX + peerMaxMessageSize atomic.Uint32 + disablePeerTCPSignalMessageCSMs bool + peerBlockWiseTranferEnabled atomic.Bool + + receivedMessageReader *client.ReceivedMessageReader[*Conn] +} + +// NewConn creates connection over session and observation. 
+func NewConn( + connection *coapNet.Conn, + createBlockWise func(cc *Conn) *blockwise.BlockWise[*Conn], + inactivityMonitor InactivityMonitor, + cfg *Config, +) *Conn { + if cfg.GetToken == nil { + cfg.GetToken = message.GetToken + } + cc := Conn{ + tokenHandlerContainer: coapSync.NewMap[uint64, HandlerFunc](), + blockwiseSZX: cfg.BlockwiseSZX, + disablePeerTCPSignalMessageCSMs: cfg.DisablePeerTCPSignalMessageCSMs, + } + limitParallelRequests := limitparallelrequests.New(cfg.LimitClientParallelRequests, cfg.LimitClientEndpointParallelRequests, cc.do, cc.doObserve) + cc.observationHandler = observation.NewHandler(&cc, cfg.Handler, limitParallelRequests.Do) + cc.Client = client.New(&cc, cc.observationHandler, cfg.GetToken, limitParallelRequests) + cc.blockWise = createBlockWise(&cc) + session := NewSession(cfg.Ctx, + connection, + cfg.MaxMessageSize, + cfg.Errors, + cfg.DisableTCPSignalMessageCSM, + cfg.CloseSocket, + inactivityMonitor, + cfg.ConnectionCacheSize, + cfg.MessagePool, + ) + cc.session = session + if cc.processReceivedMessage == nil { + cc.processReceivedMessage = processReceivedMessage + } + cc.receivedMessageReader = client.NewReceivedMessageReader(&cc, cfg.ReceivedMessageQueueSize) + return &cc +} + +func processReceivedMessage(req *pool.Message, cc *Conn, handler HandlerFunc) { + cc.ProcessReceivedMessageWithHandler(req, handler) +} + +func (cc *Conn) ProcessReceivedMessage(req *pool.Message) { + cc.processReceivedMessage(req, cc, cc.handle) +} + +func (cc *Conn) Session() *Session { + return cc.session +} + +// Close closes connection without wait of ends Run function. +func (cc *Conn) Close() error { + err := cc.session.Close() + if errors.Is(err, net.ErrClosed) { + return nil + } + return err +} + +func (cc *Conn) doInternal(req *pool.Message) (*pool.Message, error) { + token := req.Token() + if token == nil { + return nil, fmt.Errorf("invalid token") + } + respChan := make(chan *pool.Message, 1) + if _, loaded := cc.tokenHandlerContainer.LoadOrStore(token.Hash(), func(w *responsewriter.ResponseWriter[*Conn], r *pool.Message) { + r.Hijack() + select { + case respChan <- r: + default: + } + }); loaded { + return nil, fmt.Errorf("cannot add token handler: %w", coapErrors.ErrKeyAlreadyExists) + } + defer func() { + _, _ = cc.tokenHandlerContainer.LoadAndDelete(token.Hash()) + }() + if err := cc.session.WriteMessage(req); err != nil { + return nil, fmt.Errorf("cannot write request: %w", err) + } + + cc.receivedMessageReader.TryToReplaceLoop() + + select { + case <-req.Context().Done(): + return nil, req.Context().Err() + case <-cc.session.Context().Done(): + return nil, fmt.Errorf("connection was closed: %w", cc.Context().Err()) + case resp := <-respChan: + return resp, nil + } +} + +// Do sends an coap message and returns an coap response. +// +// An error is returned if by failure to speak COAP (such as a network connectivity problem). +// Any status code doesn't cause an error. +// +// Caller is responsible to release request and response. +func (cc *Conn) do(req *pool.Message) (*pool.Message, error) { + if !cc.peerBlockWiseTranferEnabled.Load() || cc.blockWise == nil { + return cc.doInternal(req) + } + resp, err := cc.blockWise.Do(req, cc.blockwiseSZX, cc.session.maxMessageSize, cc.doInternal) + if err != nil { + return nil, err + } + return resp, nil +} + +func (cc *Conn) writeMessage(req *pool.Message) error { + return cc.session.WriteMessage(req) +} + +// WriteMessage sends an coap message. 
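doInternal above registers a per-token handler, writes the request, and waits on the response channel; Do layers blockwise transfer on top once the peer's CSM advertises support. A client-side sketch exercising that path; it assumes the Do, SetPath and SetContentFormat helpers on the embedded client.Client and pool.Message, which are not shown in this diff:

```go
package main

import (
	"bytes"
	"context"
	"log"

	"github.com/plgd-dev/go-coap/v3/message"
	"github.com/plgd-dev/go-coap/v3/message/codes"
	"github.com/plgd-dev/go-coap/v3/tcp"
)

func main() {
	cc, err := tcp.Dial("localhost:5683")
	if err != nil {
		log.Fatal(err)
	}
	defer cc.Close()

	req := cc.AcquireMessage(context.Background())
	defer cc.ReleaseMessage(req)

	token, err := message.GetToken()
	if err != nil {
		log.Fatal(err)
	}
	req.SetToken(token)
	req.SetCode(codes.POST)
	if err := req.SetPath("/sensors"); err != nil {
		log.Fatal(err)
	}
	req.SetContentFormat(message.TextPlain)
	req.SetBody(bytes.NewReader([]byte("21.5")))

	resp, err := cc.Do(req) // routed through the blockwise layer when enabled
	if err != nil {
		log.Fatal(err)
	}
	defer cc.ReleaseMessage(resp)
	log.Println("response code:", resp.Code())
}
```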
+func (cc *Conn) WriteMessage(req *pool.Message) error { + if !cc.peerBlockWiseTranferEnabled.Load() || cc.blockWise == nil { + return cc.writeMessage(req) + } + return cc.blockWise.WriteMessage(req, cc.blockwiseSZX, cc.Session().maxMessageSize, cc.writeMessage) +} + +// Context returns the client's context. +// +// If connections was closed context is cancelled. +func (cc *Conn) Context() context.Context { + return cc.session.Context() +} + +// AsyncPing sends ping and receivedPong will be called when pong arrives. It returns cancellation of ping operation. +func (cc *Conn) AsyncPing(receivedPong func()) (func(), error) { + token, err := message.GetToken() + if err != nil { + return nil, fmt.Errorf("cannot get token: %w", err) + } + req := cc.session.messagePool.AcquireMessage(cc.Context()) + req.SetToken(token) + req.SetCode(codes.Ping) + defer cc.ReleaseMessage(req) + + if _, loaded := cc.tokenHandlerContainer.LoadOrStore(token.Hash(), func(w *responsewriter.ResponseWriter[*Conn], r *pool.Message) { + if r.Code() == codes.Pong { + receivedPong() + } + }); loaded { + return nil, fmt.Errorf("cannot add token handler: %w", coapErrors.ErrKeyAlreadyExists) + } + removeTokenHandler := func() { + _, _ = cc.tokenHandlerContainer.LoadAndDelete(token.Hash()) + } + err = cc.session.WriteMessage(req) + if err != nil { + removeTokenHandler() + return nil, fmt.Errorf("cannot write request: %w", err) + } + return removeTokenHandler, nil +} + +// Run reads and process requests from a connection, until the connection is not closed. +func (cc *Conn) Run() (err error) { + return cc.session.Run(cc) +} + +// AddOnClose calls function on close connection event. +func (cc *Conn) AddOnClose(f EventFunc) { + cc.session.AddOnClose(f) +} + +// RemoteAddr gets remote address. +func (cc *Conn) RemoteAddr() net.Addr { + return cc.session.RemoteAddr() +} + +func (cc *Conn) LocalAddr() net.Addr { + return cc.session.LocalAddr() +} + +// Sequence acquires sequence number. +func (cc *Conn) Sequence() uint64 { + return cc.session.Sequence() +} + +// SetContextValue stores the value associated with key to context of connection. +func (cc *Conn) SetContextValue(key interface{}, val interface{}) { + cc.session.SetContextValue(key, val) +} + +// Done signalizes that connection is not more processed. +func (cc *Conn) Done() <-chan struct{} { + return cc.session.Done() +} + +// CheckExpirations checks and remove expired items from caches. +func (cc *Conn) CheckExpirations(now time.Time) { + cc.session.CheckExpirations(now, cc) + if cc.blockWise != nil { + cc.blockWise.CheckExpirations(now) + } +} + +func (cc *Conn) AcquireMessage(ctx context.Context) *pool.Message { + return cc.session.AcquireMessage(ctx) +} + +func (cc *Conn) ReleaseMessage(m *pool.Message) { + cc.session.ReleaseMessage(m) +} + +// NetConn returns the underlying connection that is wrapped by cc. The Conn returned is shared by all invocations of NetConn, so do not modify it. +func (cc *Conn) NetConn() net.Conn { + return cc.session.NetConn() +} + +// DoObserve subscribes for every change with request. +func (cc *Conn) doObserve(req *pool.Message, observeFunc func(req *pool.Message)) (client.Observation, error) { + return cc.observationHandler.NewObservation(req, observeFunc) +} + +func (cc *Conn) ProcessReceivedMessageWithHandler(req *pool.Message, handler HandlerFunc) { + origResp := cc.AcquireMessage(cc.Context()) + origResp.SetToken(req.Token()) + w := responsewriter.New(origResp, cc, req.Options()...) 
+ handler(w, req) + defer cc.ReleaseMessage(w.Message()) + if !req.IsHijacked() { + cc.ReleaseMessage(req) + } + if w.Message().IsModified() { + err := cc.Session().WriteMessage(w.Message()) + if err != nil { + if errC := cc.Close(); errC != nil { + cc.Session().errors(fmt.Errorf("cannot close connection: %w", errC)) + } + cc.Session().errors(fmt.Errorf("cannot write response to %v: %w", cc.RemoteAddr(), err)) + } + } +} + +func (cc *Conn) blockwiseHandle(w *responsewriter.ResponseWriter[*Conn], r *pool.Message) { + if h, ok := cc.tokenHandlerContainer.Load(r.Token().Hash()); ok { + h(w, r) + return + } + cc.observationHandler.Handle(w, r) +} + +func (cc *Conn) handle(w *responsewriter.ResponseWriter[*Conn], r *pool.Message) { + if cc.blockWise != nil && cc.peerBlockWiseTranferEnabled.Load() { + cc.blockWise.Handle(w, r, cc.blockwiseSZX, cc.Session().maxMessageSize, cc.blockwiseHandle) + return + } + if h, ok := cc.tokenHandlerContainer.LoadAndDelete(r.Token().Hash()); ok { + h(w, r) + return + } + cc.observationHandler.Handle(w, r) +} + +func (cc *Conn) sendPong(token message.Token) error { + req := cc.AcquireMessage(cc.Context()) + defer cc.ReleaseMessage(req) + req.SetCode(codes.Pong) + req.SetToken(token) + return cc.Session().WriteMessage(req) +} + +func (cc *Conn) handleSignals(r *pool.Message) bool { + switch r.Code() { + case codes.CSM: + if cc.disablePeerTCPSignalMessageCSMs { + return true + } + if size, err := r.GetOptionUint32(message.TCPMaxMessageSize); err == nil { + cc.peerMaxMessageSize.Store(size) + } + if r.HasOption(message.TCPBlockWiseTransfer) { + cc.peerBlockWiseTranferEnabled.Store(true) + } + return true + case codes.Ping: + // if r.HasOption(message.TCPCustody) { + // TODO + // } + if err := cc.sendPong(r.Token()); err != nil && !coapNet.IsConnectionBrokenError(err) { + cc.Session().errors(fmt.Errorf("cannot handle ping signal: %w", err)) + } + return true + case codes.Release: + // if r.HasOption(message.TCPAlternativeAddress) { + // TODO + // } + return true + case codes.Abort: + // if r.HasOption(message.TCPBadCSMOption) { + // TODO + // } + return true + case codes.Pong: + if h, ok := cc.tokenHandlerContainer.LoadAndDelete(r.Token().Hash()); ok { + cc.processReceivedMessage(r, cc, h) + } + return true + } + return false +} + +func (cc *Conn) pushToReceivedMessageQueue(r *pool.Message) { + if cc.handleSignals(r) { + return + } + select { + case cc.receivedMessageReader.C() <- r: + case <-cc.Context().Done(): + } +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/tcp/client/session.go b/vendor/github.com/plgd-dev/go-coap/v3/tcp/client/session.go new file mode 100644 index 0000000000..8332b263d3 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/tcp/client/session.go @@ -0,0 +1,272 @@ +package client + +import ( + "bytes" + "context" + "errors" + "fmt" + "net" + "sync" + "time" + + "github.com/plgd-dev/go-coap/v3/message" + "github.com/plgd-dev/go-coap/v3/message/codes" + "github.com/plgd-dev/go-coap/v3/message/pool" + coapNet "github.com/plgd-dev/go-coap/v3/net" + "github.com/plgd-dev/go-coap/v3/net/monitor/inactivity" + "github.com/plgd-dev/go-coap/v3/tcp/coder" + "go.uber.org/atomic" +) + +type Session struct { + // This field needs to be the first in the struct to ensure proper word alignment on 32-bit platforms. 
+ // See: https://golang.org/pkg/sync/atomic/#pkg-note-BUG + sequence atomic.Uint64 + inactivityMonitor InactivityMonitor + errSendCSM error + cancel context.CancelFunc + done chan struct{} + errors ErrorFunc + connection *coapNet.Conn + messagePool *pool.Pool + ctx atomic.Value // TODO: change to atomic.Pointer[context.Context] for go1.19 + maxMessageSize uint32 + private struct { + mutex sync.Mutex + onClose []EventFunc + } + connectionCacheSize uint16 + disableTCPSignalMessageCSM bool + closeSocket bool +} + +func NewSession( + ctx context.Context, + connection *coapNet.Conn, + maxMessageSize uint32, + errors ErrorFunc, + disableTCPSignalMessageCSM bool, + closeSocket bool, + inactivityMonitor InactivityMonitor, + connectionCacheSize uint16, + messagePool *pool.Pool, +) *Session { + ctx, cancel := context.WithCancel(ctx) + if errors == nil { + errors = func(error) { + // default no-op + } + } + if inactivityMonitor == nil { + inactivityMonitor = inactivity.NewNilMonitor[*Conn]() + } + + s := &Session{ + cancel: cancel, + connection: connection, + maxMessageSize: maxMessageSize, + errors: errors, + disableTCPSignalMessageCSM: disableTCPSignalMessageCSM, + closeSocket: closeSocket, + inactivityMonitor: inactivityMonitor, + done: make(chan struct{}), + connectionCacheSize: connectionCacheSize, + messagePool: messagePool, + } + s.ctx.Store(&ctx) + + if !disableTCPSignalMessageCSM { + err := s.sendCSM() + if err != nil { + s.errSendCSM = fmt.Errorf("cannot send CSM: %w", err) + } + } + + return s +} + +// SetContextValue stores the value associated with key to context of connection. +func (s *Session) SetContextValue(key interface{}, val interface{}) { + ctx := context.WithValue(s.Context(), key, val) + s.ctx.Store(&ctx) +} + +// Done signalizes that connection is not more processed. 
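AsyncPing (defined above) registers a Pong handler under a fresh token and returns a cancel function that removes it. A sketch of a one-shot liveness check, illustrative only:

```go
package main

import (
	"log"
	"time"

	"github.com/plgd-dev/go-coap/v3/tcp"
)

func main() {
	cc, err := tcp.Dial("localhost:5683")
	if err != nil {
		log.Fatal(err)
	}
	defer cc.Close()

	cancel, err := cc.AsyncPing(func() {
		log.Println("pong received, peer is alive")
	})
	if err != nil {
		log.Fatal(err)
	}
	defer cancel() // drops the token handler if no pong ever arrives

	time.Sleep(2 * time.Second) // crude wait for the demo
}
```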
+func (s *Session) Done() <-chan struct{} { + return s.done +} + +func (s *Session) AddOnClose(f EventFunc) { + s.private.mutex.Lock() + defer s.private.mutex.Unlock() + s.private.onClose = append(s.private.onClose, f) +} + +func (s *Session) popOnClose() []EventFunc { + s.private.mutex.Lock() + defer s.private.mutex.Unlock() + tmp := s.private.onClose + s.private.onClose = nil + return tmp +} + +func (s *Session) shutdown() { + defer close(s.done) + for _, f := range s.popOnClose() { + f() + } +} + +func (s *Session) Close() error { + s.cancel() + if s.closeSocket { + return s.connection.Close() + } + return nil +} + +func (s *Session) Sequence() uint64 { + return s.sequence.Inc() +} + +func (s *Session) Context() context.Context { + return *s.ctx.Load().(*context.Context) //nolint:forcetypeassert +} + +func seekBufferToNextMessage(buffer *bytes.Buffer, msgSize int) *bytes.Buffer { + if msgSize == buffer.Len() { + // buffer is empty so reset it + buffer.Reset() + return buffer + } + // rewind to next message + trimmed := 0 + for trimmed != msgSize { + b := make([]byte, 4096) + max := 4096 + if msgSize-trimmed < max { + max = msgSize - trimmed + } + v, _ := buffer.Read(b[:max]) + trimmed += v + } + return buffer +} + +func (s *Session) processBuffer(buffer *bytes.Buffer, cc *Conn) error { + for buffer.Len() > 0 { + var header coder.MessageHeader + _, err := coder.DefaultCoder.DecodeHeader(buffer.Bytes(), &header) + if errors.Is(err, message.ErrShortRead) { + return nil + } + if header.MessageLength > s.maxMessageSize { + return fmt.Errorf("max message size(%v) was exceeded %v", s.maxMessageSize, header.MessageLength) + } + if uint32(buffer.Len()) < header.MessageLength { + return nil + } + req := s.messagePool.AcquireMessage(s.Context()) + read, err := req.UnmarshalWithDecoder(coder.DefaultCoder, buffer.Bytes()[:header.MessageLength]) + if err != nil { + s.messagePool.ReleaseMessage(req) + return fmt.Errorf("cannot unmarshal with header: %w", err) + } + buffer = seekBufferToNextMessage(buffer, read) + req.SetSequence(s.Sequence()) + s.inactivityMonitor.Notify() + cc.pushToReceivedMessageQueue(req) + } + return nil +} + +func (s *Session) WriteMessage(req *pool.Message) error { + data, err := req.MarshalWithEncoder(coder.DefaultCoder) + if err != nil { + return fmt.Errorf("cannot marshal: %w", err) + } + err = s.connection.WriteWithContext(req.Context(), data) + if err != nil { + return fmt.Errorf("cannot write to connection: %w", err) + } + return err +} + +func (s *Session) sendCSM() error { + token, err := message.GetToken() + if err != nil { + return fmt.Errorf("cannot get token: %w", err) + } + req := s.messagePool.AcquireMessage(s.Context()) + defer s.messagePool.ReleaseMessage(req) + req.SetCode(codes.CSM) + req.SetToken(token) + return s.WriteMessage(req) +} + +func shrinkBufferIfNecessary(buffer *bytes.Buffer, maxCap uint16) *bytes.Buffer { + if buffer.Len() == 0 && buffer.Cap() > int(maxCap) { + buffer = bytes.NewBuffer(make([]byte, 0, maxCap)) + } + return buffer +} + +// Run reads and process requests from a connection, until the connection is not closed. 
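Close cancels the session context; Run's deferred shutdown then fires the registered onClose callbacks and closes the done channel, so Done() can be used to wait for a clean teardown. A minimal sketch:

```go
package main

import (
	"log"

	"github.com/plgd-dev/go-coap/v3/tcp"
)

func main() {
	cc, err := tcp.Dial("localhost:5683")
	if err != nil {
		log.Fatal(err)
	}
	cc.AddOnClose(func() {
		log.Println("connection closed")
	})

	_ = cc.Close()
	<-cc.Done() // returns once Run's shutdown has completed
}
```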
+func (s *Session) Run(cc *Conn) (err error) { + defer func() { + err1 := s.Close() + if err == nil { + err = err1 + } + s.shutdown() + }() + if s.errSendCSM != nil { + return s.errSendCSM + } + buffer := bytes.NewBuffer(make([]byte, 0, s.connectionCacheSize)) + readBuf := make([]byte, s.connectionCacheSize) + for { + err = s.processBuffer(buffer, cc) + if err != nil { + return err + } + buffer = shrinkBufferIfNecessary(buffer, s.connectionCacheSize) + readLen, err := s.connection.ReadWithContext(s.Context(), readBuf) + if err != nil { + if coapNet.IsConnectionBrokenError(err) { // other side closed the connection, ignore the error and return + return nil + } + return fmt.Errorf("cannot read from connection: %w", err) + } + if readLen > 0 { + buffer.Write(readBuf[:readLen]) + } + } +} + +// CheckExpirations checks and remove expired items from caches. +func (s *Session) CheckExpirations(now time.Time, cc *Conn) { + s.inactivityMonitor.CheckInactivity(now, cc) +} + +func (s *Session) AcquireMessage(ctx context.Context) *pool.Message { + return s.messagePool.AcquireMessage(ctx) +} + +func (s *Session) ReleaseMessage(m *pool.Message) { + s.messagePool.ReleaseMessage(m) +} + +// RemoteAddr gets remote address. +func (s *Session) RemoteAddr() net.Addr { + return s.connection.RemoteAddr() +} + +func (s *Session) LocalAddr() net.Addr { + return s.connection.LocalAddr() +} + +// NetConn returns the underlying connection that is wrapped by s. The Conn returned is shared by all invocations of NetConn, so do not modify it. +func (s *Session) NetConn() net.Conn { + return s.connection.NetConn() +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/tcp/coder/coder.go b/vendor/github.com/plgd-dev/go-coap/v3/tcp/coder/coder.go new file mode 100644 index 0000000000..9979370cad --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/tcp/coder/coder.go @@ -0,0 +1,254 @@ +package coder + +import ( + "encoding/binary" + "errors" + + "github.com/plgd-dev/go-coap/v3/message" + "github.com/plgd-dev/go-coap/v3/message/codes" +) + +var DefaultCoder = new(Coder) + +const ( + MessageLength13Base = 13 + MessageLength14Base = 269 + MessageLength15Base = 65805 + messageMaxLen = 0x7fff0000 // Large number that works in 32-bit builds +) + +type Coder struct{} + +type MessageHeader struct { + Token []byte + Length uint32 + MessageLength uint32 + Code codes.Code +} + +func (c *Coder) Size(m message.Message) (int, error) { + size, err := c.Encode(m, nil) + if errors.Is(err, message.ErrTooSmall) { + err = nil + } + return size, err +} + +func getHeader(messageLength int) (uint8, []byte) { + if messageLength < MessageLength13Base { + return uint8(messageLength), nil + } + if messageLength < MessageLength14Base { + extLen := messageLength - MessageLength13Base + extLenBytes := []byte{uint8(extLen)} + return 13, extLenBytes + } + if messageLength < MessageLength15Base { + extLen := messageLength - MessageLength14Base + extLenBytes := make([]byte, 2) + binary.BigEndian.PutUint16(extLenBytes, uint16(extLen)) + return 14, extLenBytes + } + if messageLength < messageMaxLen { + extLen := messageLength - MessageLength15Base + extLenBytes := make([]byte, 4) + binary.BigEndian.PutUint32(extLenBytes, uint32(extLen)) + return 15, extLenBytes + } + return 0, nil +} + +func (c *Coder) Encode(m message.Message, buf []byte) (int, error) { + /* + A CoAP Message message lomessage.OKs like: + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Len 
| TKL | Extended Length ... + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Code | TKL bytes ... + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Options (if any) ... + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + |1 1 1 1 1 1 1 1| Payload (if any) ... + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + + The size of the Extended Length field is inferred from the value of the + Len field as follows: + + | Len value | Extended Length size | Total length | + +------------+-----------------------+---------------------------+ + | 0-12 | 0 | Len | + | 13 | 1 | Extended Length + 13 | + | 14 | 2 | Extended Length + 269 | + | 15 | 4 | Extended Length + 65805 | + */ + + if len(m.Token) > message.MaxTokenSize { + return -1, message.ErrInvalidTokenLen + } + + payloadLen := len(m.Payload) + if payloadLen > 0 { + // for separator 0xff + payloadLen++ + } + optionsLen, err := m.Options.Marshal(nil) + if !errors.Is(err, message.ErrTooSmall) { + return -1, err + } + bufLen := payloadLen + optionsLen + lenNib, extLenBytes := getHeader(bufLen) + + var hdr [1 + 4 + message.MaxTokenSize + 1]byte + hdrLen := 1 + len(extLenBytes) + len(m.Token) + 1 + hdrOff := 0 + + copyToHdr := func(offset int, data []byte) int { + if len(data) > 0 { + copy(hdr[hdrOff:hdrOff+len(data)], data) + offset += len(data) + } + return offset + } + + // Length and TKL nibbles. + hdr[hdrOff] = uint8(0xf&len(m.Token)) | (lenNib << 4) + hdrOff++ + + // Extended length, if present. + hdrOff = copyToHdr(hdrOff, extLenBytes) + + // Code. + hdr[hdrOff] = byte(m.Code) + hdrOff++ + + // Token. + copyToHdr(hdrOff, m.Token) + + bufLen += hdrLen + if len(buf) < bufLen { + return bufLen, message.ErrTooSmall + } + + copy(buf, hdr[:hdrLen]) + optionsLen, err = m.Options.Marshal(buf[hdrLen:]) + switch { + case err == nil: + case errors.Is(err, message.ErrTooSmall): + return bufLen, err + default: + return -1, err + } + if len(m.Payload) > 0 { + copy(buf[hdrLen+optionsLen:], []byte{0xff}) + copy(buf[hdrLen+optionsLen+1:], m.Payload) + } + + return bufLen, nil +} + +func (c *Coder) DecodeHeader(data []byte, h *MessageHeader) (int, error) { + hdrOff := uint32(0) + if len(data) == 0 { + return -1, message.ErrShortRead + } + + firstByte := data[0] + data = data[1:] + hdrOff++ + + lenNib := (firstByte & 0xf0) >> 4 + tkl := firstByte & 0x0f + + var opLen int + switch { + case lenNib < MessageLength13Base: + opLen = int(lenNib) + case lenNib == 13: + if len(data) < 1 { + return -1, message.ErrShortRead + } + extLen := data[0] + data = data[1:] + hdrOff++ + opLen = MessageLength13Base + int(extLen) + case lenNib == 14: + if len(data) < 2 { + return -1, message.ErrShortRead + } + extLen := binary.BigEndian.Uint16(data) + data = data[2:] + hdrOff += 2 + opLen = MessageLength14Base + int(extLen) + case lenNib == 15: + if len(data) < 4 { + return -1, message.ErrShortRead + } + extLen := binary.BigEndian.Uint32(data) + data = data[4:] + hdrOff += 4 + opLen = MessageLength15Base + int(extLen) + } + + h.MessageLength = hdrOff + 1 + uint32(tkl) + uint32(opLen) + if len(data) < 1 { + return -1, message.ErrShortRead + } + h.Code = codes.Code(data[0]) + data = data[1:] + hdrOff++ + if len(data) < int(tkl) { + return -1, message.ErrShortRead + } + if tkl > 0 { + h.Token = data[:tkl] + } + hdrOff += uint32(tkl) + h.Length = hdrOff + return int(h.Length), nil +} + +func (c *Coder) DecodeWithHeader(data []byte, header MessageHeader, m *message.Message) (int, error) { + optionDefs := 
message.CoapOptionDefs + processed := header.Length + switch header.Code { + case codes.CSM: + optionDefs = message.TCPSignalCSMOptionDefs + case codes.Ping, codes.Pong: + optionDefs = message.TCPSignalPingPongOptionDefs + case codes.Release: + optionDefs = message.TCPSignalReleaseOptionDefs + case codes.Abort: + optionDefs = message.TCPSignalAbortOptionDefs + } + + proc, err := m.Options.Unmarshal(data, optionDefs) + if err != nil { + return -1, err + } + data = data[proc:] + processed += uint32(proc) + + if len(data) > 0 { + m.Payload = data + } + processed += uint32(len(data)) + m.Code = header.Code + m.Token = header.Token + + return int(processed), nil +} + +func (c *Coder) Decode(data []byte, m *message.Message) (int, error) { + var header MessageHeader + _, err := c.DecodeHeader(data, &header) + if err != nil { + return -1, err + } + if uint32(len(data)) < header.MessageLength { + return -1, message.ErrShortRead + } + return c.DecodeWithHeader(data[header.Length:], header, m) +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/tcp/coder/error.go b/vendor/github.com/plgd-dev/go-coap/v3/tcp/coder/error.go new file mode 100644 index 0000000000..ac7ae99aad --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/tcp/coder/error.go @@ -0,0 +1,8 @@ +package coder + +import "errors" + +var ( + ErrMessageTruncated = errors.New("message is truncated") + ErrMessageInvalidVersion = errors.New("message has invalid version") +) diff --git a/vendor/github.com/plgd-dev/go-coap/v3/tcp/server.go b/vendor/github.com/plgd-dev/go-coap/v3/tcp/server.go new file mode 100644 index 0000000000..d5ebbffcbe --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/tcp/server.go @@ -0,0 +1,9 @@ +package tcp + +import ( + "github.com/plgd-dev/go-coap/v3/tcp/server" +) + +func NewServer(opt ...server.Option) *server.Server { + return server.New(opt...) +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/tcp/server/config.go b/vendor/github.com/plgd-dev/go-coap/v3/tcp/server/config.go new file mode 100644 index 0000000000..6147ec36e9 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/tcp/server/config.go @@ -0,0 +1,62 @@ +package server + +import ( + "fmt" + "time" + + "github.com/plgd-dev/go-coap/v3/message" + "github.com/plgd-dev/go-coap/v3/message/codes" + "github.com/plgd-dev/go-coap/v3/message/pool" + "github.com/plgd-dev/go-coap/v3/net/monitor/inactivity" + "github.com/plgd-dev/go-coap/v3/net/responsewriter" + "github.com/plgd-dev/go-coap/v3/options/config" + "github.com/plgd-dev/go-coap/v3/tcp/client" +) + +// The HandlerFunc type is an adapter to allow the use of +// ordinary functions as COAP handlers. +type HandlerFunc = func(*responsewriter.ResponseWriter[*client.Conn], *pool.Message) + +type ErrorFunc = func(error) + +type GoPoolFunc = func(func()) error + +// OnNewConnFunc is the callback for new connections. 
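For orientation when reading the TCP coder above, here is a minimal round-trip sketch using only the types defined in this file plus the vendored `message` and `codes` packages. The path option and payload values are illustrative assumptions, not taken from this diff:

```go
package main

import (
	"fmt"

	"github.com/plgd-dev/go-coap/v3/message"
	"github.com/plgd-dev/go-coap/v3/message/codes"
	"github.com/plgd-dev/go-coap/v3/tcp/coder"
)

func main() {
	// One option and a payload, so the Len/TKL header, option and
	// payload sections of the TCP frame are all exercised.
	msg := message.Message{
		Code:    codes.GET,
		Token:   []byte{0x01, 0x02},
		Options: message.Options{{ID: message.URIPath, Value: []byte("sensors")}},
		Payload: []byte("hi"),
	}

	// Size reports the encoded length; Encode then writes the frame.
	size, err := coder.DefaultCoder.Size(msg)
	if err != nil {
		panic(err)
	}
	buf := make([]byte, size)
	n, err := coder.DefaultCoder.Encode(msg, buf)
	if err != nil {
		panic(err)
	}

	// Decode parses the header (DecodeHeader) and body (DecodeWithHeader).
	var decoded message.Message
	if _, err := coder.DefaultCoder.Decode(buf[:n], &decoded); err != nil {
		panic(err)
	}
	fmt.Printf("code=%v token=%x payload=%q\n", decoded.Code, decoded.Token, decoded.Payload)
}
```

Since the total frame here is under 13 bytes of options plus payload, the Len nibble carries the length directly and no Extended Length bytes are emitted.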
+type OnNewConnFunc = func(cc *client.Conn) + +var DefaultConfig = func() Config { + opts := Config{ + Common: config.NewCommon[*client.Conn](), + CreateInactivityMonitor: func() client.InactivityMonitor { + maxRetries := uint32(2) + timeout := time.Second * 16 + onInactive := func(cc *client.Conn) { + _ = cc.Close() + } + keepalive := inactivity.NewKeepAlive(maxRetries, onInactive, func(cc *client.Conn, receivePong func()) (func(), error) { + return cc.AsyncPing(receivePong) + }) + return inactivity.New(timeout/time.Duration(maxRetries+1), keepalive.OnInactive) + }, + OnNewConn: func(cc *client.Conn) { + // do nothing by default + }, + ConnectionCacheSize: 2 * 1024, + } + opts.Handler = func(w *responsewriter.ResponseWriter[*client.Conn], r *pool.Message) { + if err := w.SetResponse(codes.NotFound, message.TextPlain, nil); err != nil { + opts.Errors(fmt.Errorf("server handler: cannot set response: %w", err)) + } + } + return opts +}() + +type Config struct { + config.Common[*client.Conn] + CreateInactivityMonitor client.CreateInactivityMonitorFunc + Handler HandlerFunc + OnNewConn OnNewConnFunc + ConnectionCacheSize uint16 + DisablePeerTCPSignalMessageCSMs bool + DisableTCPSignalMessageCSM bool +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/tcp/server/server.go b/vendor/github.com/plgd-dev/go-coap/v3/tcp/server/server.go new file mode 100644 index 0000000000..a869e29847 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/tcp/server/server.go @@ -0,0 +1,223 @@ +package server + +import ( + "context" + "errors" + "fmt" + "net" + "sync" + "time" + + "github.com/plgd-dev/go-coap/v3/message" + "github.com/plgd-dev/go-coap/v3/message/pool" + coapNet "github.com/plgd-dev/go-coap/v3/net" + "github.com/plgd-dev/go-coap/v3/net/blockwise" + "github.com/plgd-dev/go-coap/v3/net/monitor/inactivity" + "github.com/plgd-dev/go-coap/v3/pkg/connections" + "github.com/plgd-dev/go-coap/v3/tcp/client" +) + +// A Option sets options such as credentials, codec and keepalive parameters, etc. +type Option interface { + TCPServerApply(cfg *Config) +} + +// Listener defined used by coap +type Listener interface { + Close() error + AcceptWithContext(ctx context.Context) (net.Conn, error) +} + +type Server struct { + listenMutex sync.Mutex + listen Listener + ctx context.Context + cancel context.CancelFunc + cfg *Config +} + +func New(opt ...Option) *Server { + cfg := DefaultConfig + for _, o := range opt { + o.TCPServerApply(&cfg) + } + + ctx, cancel := context.WithCancel(cfg.Ctx) + + if cfg.CreateInactivityMonitor == nil { + cfg.CreateInactivityMonitor = func() client.InactivityMonitor { + return inactivity.NewNilMonitor[*client.Conn]() + } + } + if cfg.MessagePool == nil { + cfg.MessagePool = pool.New(0, 0) + } + + if cfg.Errors == nil { + cfg.Errors = func(error) { + // default no-op + } + } + if cfg.GetToken == nil { + cfg.GetToken = message.GetToken + } + errorsFunc := cfg.Errors + // assign updated func to opts.errors so opts.handler also uses the updated error handler + cfg.Errors = func(err error) { + if coapNet.IsCancelOrCloseError(err) { + // this error was produced by cancellation context or closing connection. 
+			return
+		}
+		errorsFunc(fmt.Errorf("tcp: %w", err))
+	}
+
+	return &Server{
+		ctx:    ctx,
+		cancel: cancel,
+		cfg:    &cfg,
+	}
+}
+
+func (s *Server) checkAndSetListener(l Listener) error {
+	s.listenMutex.Lock()
+	defer s.listenMutex.Unlock()
+	if s.listen != nil {
+		return fmt.Errorf("server already serves listener")
+	}
+	s.listen = l
+	return nil
+}
+
+func (s *Server) popListener() Listener {
+	s.listenMutex.Lock()
+	defer s.listenMutex.Unlock()
+	l := s.listen
+	s.listen = nil
+	return l
+}
+
+func (s *Server) checkAcceptError(err error) bool {
+	if err == nil {
+		return true
+	}
+	switch {
+	case errors.Is(err, coapNet.ErrListenerIsClosed):
+		s.Stop()
+		return false
+	case errors.Is(err, context.DeadlineExceeded), errors.Is(err, context.Canceled):
+		select {
+		case <-s.ctx.Done():
+		default:
+			s.cfg.Errors(fmt.Errorf("cannot accept connection: %w", err))
+			return true
+		}
+		return false
+	default:
+		return true
+	}
+}
+
+func (s *Server) serveConnection(connections *connections.Connections, rw net.Conn) {
+	var cc *client.Conn
+	monitor := s.cfg.CreateInactivityMonitor()
+	cc = s.createConn(coapNet.NewConn(rw), monitor)
+	if s.cfg.OnNewConn != nil {
+		s.cfg.OnNewConn(cc)
+	}
+	connections.Store(cc)
+	defer connections.Delete(cc)
+
+	if err := cc.Run(); err != nil {
+		s.cfg.Errors(fmt.Errorf("%v: %w", cc.RemoteAddr(), err))
+	}
+}
+
+func (s *Server) Serve(l Listener) error {
+	if s.cfg.BlockwiseSZX > blockwise.SZXBERT {
+		return fmt.Errorf("invalid blockwiseSZX")
+	}
+
+	err := s.checkAndSetListener(l)
+	if err != nil {
+		return err
+	}
+	defer func() {
+		s.Stop()
+	}()
+	var wg sync.WaitGroup
+	defer wg.Wait()
+
+	connections := connections.New()
+	s.cfg.PeriodicRunner(func(now time.Time) bool {
+		connections.CheckExpirations(now)
+		return s.ctx.Err() == nil
+	})
+	defer connections.Close()
+
+	for {
+		rw, err := l.AcceptWithContext(s.ctx)
+		if ok := s.checkAcceptError(err); !ok {
+			return nil
+		}
+		if rw == nil {
+			continue
+		}
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			s.serveConnection(connections, rw)
+		}()
+	}
+}
+
+// Stop stops the server without waiting for the Serve function to end.
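A minimal sketch of driving this server: `Serve` needs any `Listener` with `Close`/`AcceptWithContext`, and `coapNet.NewTCPListener` (assumed here, not part of this diff) provides one:

```go
package main

import (
	"log"

	coapNet "github.com/plgd-dev/go-coap/v3/net"
	"github.com/plgd-dev/go-coap/v3/tcp"
)

func main() {
	// The TCP listener satisfies the server.Listener interface above.
	l, err := coapNet.NewTCPListener("tcp", ":5683")
	if err != nil {
		log.Fatalf("cannot create listener: %v", err)
	}
	defer l.Close()

	s := tcp.NewServer() // DefaultConfig: NotFound handler, keepalive monitor
	defer s.Stop()

	// Serve blocks until Stop is called or the listener closes;
	// each accepted connection runs in its own goroutine.
	if err := s.Serve(l); err != nil {
		log.Fatalf("serve: %v", err)
	}
}
```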
+func (s *Server) Stop() { + s.cancel() + l := s.popListener() + if l == nil { + return + } + if err := l.Close(); err != nil { + s.cfg.Errors(fmt.Errorf("cannot close listener: %w", err)) + } +} + +func (s *Server) createConn(connection *coapNet.Conn, monitor client.InactivityMonitor) *client.Conn { + createBlockWise := func(cc *client.Conn) *blockwise.BlockWise[*client.Conn] { + return nil + } + if s.cfg.BlockwiseEnable { + createBlockWise = func(cc *client.Conn) *blockwise.BlockWise[*client.Conn] { + return blockwise.New( + cc, + s.cfg.BlockwiseTransferTimeout, + s.cfg.Errors, + func(token message.Token) (*pool.Message, bool) { + return nil, false + }, + ) + } + } + cfg := client.DefaultConfig + cfg.Ctx = s.ctx + cfg.Handler = s.cfg.Handler + cfg.MaxMessageSize = s.cfg.MaxMessageSize + cfg.Errors = s.cfg.Errors + cfg.BlockwiseSZX = s.cfg.BlockwiseSZX + cfg.DisablePeerTCPSignalMessageCSMs = s.cfg.DisablePeerTCPSignalMessageCSMs + cfg.DisableTCPSignalMessageCSM = s.cfg.DisableTCPSignalMessageCSM + cfg.CloseSocket = true + cfg.ConnectionCacheSize = s.cfg.ConnectionCacheSize + cfg.MessagePool = s.cfg.MessagePool + cfg.GetToken = s.cfg.GetToken + cfg.ProcessReceivedMessage = s.cfg.ProcessReceivedMessage + cfg.ReceivedMessageQueueSize = s.cfg.ReceivedMessageQueueSize + cc := client.NewConn( + connection, + createBlockWise, + monitor, + &cfg, + ) + + return cc +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/udp/client.go b/vendor/github.com/plgd-dev/go-coap/v3/udp/client.go new file mode 100644 index 0000000000..26e7a011c0 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/udp/client.go @@ -0,0 +1,116 @@ +package udp + +import ( + "context" + "fmt" + "net" + "time" + + "github.com/plgd-dev/go-coap/v3/message" + "github.com/plgd-dev/go-coap/v3/message/pool" + coapNet "github.com/plgd-dev/go-coap/v3/net" + "github.com/plgd-dev/go-coap/v3/net/blockwise" + "github.com/plgd-dev/go-coap/v3/net/monitor/inactivity" + "github.com/plgd-dev/go-coap/v3/options" + "github.com/plgd-dev/go-coap/v3/udp/client" + "github.com/plgd-dev/go-coap/v3/udp/server" +) + +// A Option sets options such as credentials, keepalive parameters, etc. +type Option interface { + UDPClientApply(cfg *client.Config) +} + +// Dial creates a client connection to the given target. +func Dial(target string, opts ...Option) (*client.Conn, error) { + cfg := client.DefaultConfig + for _, o := range opts { + o.UDPClientApply(&cfg) + } + c, err := cfg.Dialer.DialContext(cfg.Ctx, cfg.Net, target) + if err != nil { + return nil, err + } + conn, ok := c.(*net.UDPConn) + if !ok { + return nil, fmt.Errorf("unsupported connection type: %T", c) + } + opts = append(opts, options.WithCloseSocket()) + return Client(conn, opts...), nil +} + +// Client creates client over udp connection. +func Client(conn *net.UDPConn, opts ...Option) *client.Conn { + cfg := client.DefaultConfig + for _, o := range opts { + o.UDPClientApply(&cfg) + } + if cfg.Errors == nil { + cfg.Errors = func(error) { + // default no-op + } + } + if cfg.CreateInactivityMonitor == nil { + cfg.CreateInactivityMonitor = func() client.InactivityMonitor { + return inactivity.NewNilMonitor[*client.Conn]() + } + } + if cfg.MessagePool == nil { + cfg.MessagePool = pool.New(0, 0) + } + + errorsFunc := cfg.Errors + cfg.Errors = func(err error) { + if coapNet.IsCancelOrCloseError(err) { + // this error was produced by cancellation context or closing connection. 
+ return + } + errorsFunc(fmt.Errorf("udp: %v: %w", conn.RemoteAddr(), err)) + } + addr, _ := conn.RemoteAddr().(*net.UDPAddr) + createBlockWise := func(cc *client.Conn) *blockwise.BlockWise[*client.Conn] { + return nil + } + if cfg.BlockwiseEnable { + createBlockWise = func(cc *client.Conn) *blockwise.BlockWise[*client.Conn] { + v := cc + return blockwise.New( + v, + cfg.BlockwiseTransferTimeout, + cfg.Errors, + func(token message.Token) (*pool.Message, bool) { + return v.GetObservationRequest(token) + }, + ) + } + } + + monitor := cfg.CreateInactivityMonitor() + l := coapNet.NewUDPConn(cfg.Net, conn, coapNet.WithErrors(cfg.Errors)) + session := server.NewSession(cfg.Ctx, + context.Background(), + l, + addr, + cfg.MaxMessageSize, + cfg.MTU, + cfg.CloseSocket, + ) + cc := client.NewConn(session, + createBlockWise, + monitor, + &cfg, + ) + cfg.PeriodicRunner(func(now time.Time) bool { + cc.CheckExpirations(now) + return cc.Context().Err() == nil + }) + + go func() { + err := cc.Run() + if err != nil { + cfg.Errors(err) + } + }() + + return cc +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/udp/client/config.go b/vendor/github.com/plgd-dev/go-coap/v3/udp/client/config.go new file mode 100644 index 0000000000..2b719f88c4 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/udp/client/config.go @@ -0,0 +1,55 @@ +package client + +import ( + "fmt" + "net" + "time" + + "github.com/plgd-dev/go-coap/v3/message" + "github.com/plgd-dev/go-coap/v3/message/codes" + "github.com/plgd-dev/go-coap/v3/message/pool" + "github.com/plgd-dev/go-coap/v3/net/monitor/inactivity" + "github.com/plgd-dev/go-coap/v3/net/responsewriter" + "github.com/plgd-dev/go-coap/v3/options/config" +) + +const DefaultMTU = 1472 + +var DefaultConfig = func() Config { + opts := Config{ + Common: config.NewCommon[*Conn](), + CreateInactivityMonitor: func() InactivityMonitor { + return inactivity.NewNilMonitor[*Conn]() + }, + Dialer: &net.Dialer{Timeout: time.Second * 3}, + Net: "udp", + TransmissionNStart: 1, + TransmissionAcknowledgeTimeout: time.Second * 2, + TransmissionMaxRetransmit: 4, + GetMID: message.GetMID, + MTU: DefaultMTU, + } + opts.Handler = func(w *responsewriter.ResponseWriter[*Conn], r *pool.Message) { + switch r.Code() { + case codes.POST, codes.PUT, codes.GET, codes.DELETE: + if err := w.SetResponse(codes.NotFound, message.TextPlain, nil); err != nil { + opts.Errors(fmt.Errorf("udp client: cannot set response: %w", err)) + } + } + } + return opts +}() + +type Config struct { + config.Common[*Conn] + CreateInactivityMonitor CreateInactivityMonitorFunc + Net string + GetMID GetMIDFunc + Handler HandlerFunc + Dialer *net.Dialer + TransmissionNStart uint32 + TransmissionAcknowledgeTimeout time.Duration + TransmissionMaxRetransmit uint32 + CloseSocket bool + MTU uint16 +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/udp/client/conn.go b/vendor/github.com/plgd-dev/go-coap/v3/udp/client/conn.go new file mode 100644 index 0000000000..8e20d8714f --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/udp/client/conn.go @@ -0,0 +1,888 @@ +package client + +import ( + "context" + "errors" + "fmt" + "math" + "net" + "sync" + "time" + + "github.com/plgd-dev/go-coap/v3/message" + "github.com/plgd-dev/go-coap/v3/message/codes" + "github.com/plgd-dev/go-coap/v3/message/pool" + coapNet "github.com/plgd-dev/go-coap/v3/net" + "github.com/plgd-dev/go-coap/v3/net/blockwise" + "github.com/plgd-dev/go-coap/v3/net/client" + limitparallelrequests "github.com/plgd-dev/go-coap/v3/net/client/limitParallelRequests" + 
"github.com/plgd-dev/go-coap/v3/net/observation" + "github.com/plgd-dev/go-coap/v3/net/responsewriter" + "github.com/plgd-dev/go-coap/v3/options/config" + "github.com/plgd-dev/go-coap/v3/pkg/cache" + coapErrors "github.com/plgd-dev/go-coap/v3/pkg/errors" + "github.com/plgd-dev/go-coap/v3/pkg/fn" + coapSync "github.com/plgd-dev/go-coap/v3/pkg/sync" + "github.com/plgd-dev/go-coap/v3/udp/coder" + "go.uber.org/atomic" + "golang.org/x/sync/semaphore" +) + +// https://datatracker.ietf.org/doc/html/rfc7252#section-4.8.2 +const ExchangeLifetime = 247 * time.Second + +type ( + HandlerFunc = func(*responsewriter.ResponseWriter[*Conn], *pool.Message) + ErrorFunc = func(error) + EventFunc = func() + GetMIDFunc = func() int32 + CreateInactivityMonitorFunc = func() InactivityMonitor +) + +type InactivityMonitor interface { + Notify() + CheckInactivity(now time.Time, cc *Conn) +} + +type Session interface { + Context() context.Context + Close() error + MaxMessageSize() uint32 + RemoteAddr() net.Addr + LocalAddr() net.Addr + // NetConn returns the underlying connection that is wrapped by Session. The Conn returned is shared by all invocations of NetConn, so do not modify it. + NetConn() net.Conn + WriteMessage(req *pool.Message) error + // WriteMulticast sends multicast to the remote multicast address. + // By default it is sent over all network interfaces and all compatible source IP addresses with hop limit 1. + // Via opts you can specify the network interface, source IP address, and hop limit. + WriteMulticastMessage(req *pool.Message, address *net.UDPAddr, opts ...coapNet.MulticastOption) error + Run(cc *Conn) error + AddOnClose(f EventFunc) + SetContextValue(key interface{}, val interface{}) + Done() <-chan struct{} +} + +type RequestsMap = coapSync.Map[uint64, *pool.Message] + +const ( + errFmtWriteRequest = "cannot write request: %w" + errFmtWriteResponse = "cannot write response: %w" +) + +type midElement struct { + handler HandlerFunc + start time.Time + deadline time.Time + retransmit atomic.Int32 + + private struct { + sync.Mutex + msg *pool.Message + } +} + +func (m *midElement) ReleaseMessage(cc *Conn) { + m.private.Lock() + defer m.private.Unlock() + if m.private.msg != nil { + cc.ReleaseMessage(m.private.msg) + m.private.msg = nil + } +} + +func (m *midElement) IsExpired(now time.Time, maxRetransmit int32) bool { + if !m.deadline.IsZero() && now.After(m.deadline) { + // remove element if deadline is exceeded + return true + } + retransmit := m.retransmit.Load() + return retransmit >= maxRetransmit +} + +func (m *midElement) Retransmit(now time.Time, acknowledgeTimeout time.Duration) bool { + if now.After(m.start.Add(acknowledgeTimeout * time.Duration(m.retransmit.Load()+1))) { + m.retransmit.Inc() + // retransmit + return true + } + // wait for next retransmit + return false +} + +func (m *midElement) GetMessage(cc *Conn) (*pool.Message, bool, error) { + m.private.Lock() + defer m.private.Unlock() + if m.private.msg == nil { + return nil, false, nil + } + msg := cc.AcquireMessage(m.private.msg.Context()) + if err := m.private.msg.Clone(msg); err != nil { + cc.ReleaseMessage(msg) + return nil, false, err + } + return msg, true, nil +} + +// Conn represents a virtual connection to a conceptual endpoint, to perform COAPs commands. +type Conn struct { + // This field needs to be the first in the struct to ensure proper word alignment on 32-bit platforms. 
+	// See: https://golang.org/pkg/sync/atomic/#pkg-note-BUG
+	sequence atomic.Uint64
+
+	session Session
+	*client.Client[*Conn]
+	inactivityMonitor InactivityMonitor
+
+	blockWise          *blockwise.BlockWise[*Conn]
+	observationHandler *observation.Handler[*Conn]
+	transmission       *Transmission
+	messagePool        *pool.Pool
+
+	processReceivedMessage config.ProcessReceivedMessageFunc[*Conn]
+	errors                 ErrorFunc
+	responseMsgCache       *cache.Cache[string, []byte]
+	msgIDMutex             *MutexMap
+
+	tokenHandlerContainer *coapSync.Map[uint64, HandlerFunc]
+	midHandlerContainer   *coapSync.Map[int32, *midElement]
+	msgID                 atomic.Uint32
+	blockwiseSZX          blockwise.SZX
+
+	/*
+	   An outstanding interaction is either a CON for which an ACK has not
+	   yet been received but is still expected (message layer) or a request
+	   for which neither a response nor an Acknowledgment message has yet
+	   been received but is still expected (which may both occur at the same
+	   time, counting as one outstanding interaction).
+	*/
+	numOutstandingInteraction *semaphore.Weighted
+	receivedMessageReader     *client.ReceivedMessageReader[*Conn]
+}
+
+// Transmission is a thread-safe container for transmission-related parameters.
+type Transmission struct {
+	nStart             *atomic.Uint32
+	acknowledgeTimeout *atomic.Duration
+	maxRetransmit      *atomic.Int32
+}
+
+// SetTransmissionNStart changes the nStart value; this only affects requests queued after the change. Requests already waiting will be unblocked once enough weight has been released.
+func (t *Transmission) SetTransmissionNStart(d uint32) {
+	t.nStart.Store(d)
+}
+
+func (t *Transmission) SetTransmissionAcknowledgeTimeout(d time.Duration) {
+	t.acknowledgeTimeout.Store(d)
+}
+
+func (t *Transmission) SetTransmissionMaxRetransmit(d int32) {
+	t.maxRetransmit.Store(d)
+}
+
+func (cc *Conn) Transmission() *Transmission {
+	return cc.transmission
+}
+
+// NewConn creates a connection over the given session and observation handler.
+func NewConn( + session Session, + createBlockWise func(cc *Conn) *blockwise.BlockWise[*Conn], + inactivityMonitor InactivityMonitor, + cfg *Config, +) *Conn { + if cfg.Errors == nil { + cfg.Errors = func(error) { + // default no-op + } + } + if cfg.GetMID == nil { + cfg.GetMID = message.GetMID + } + if cfg.GetToken == nil { + cfg.GetToken = message.GetToken + } + if cfg.ReceivedMessageQueueSize < 0 { + cfg.ReceivedMessageQueueSize = 0 + } + + cc := Conn{ + session: session, + transmission: &Transmission{ + atomic.NewUint32(cfg.TransmissionNStart), + atomic.NewDuration(cfg.TransmissionAcknowledgeTimeout), + atomic.NewInt32(int32(cfg.TransmissionMaxRetransmit)), + }, + blockwiseSZX: cfg.BlockwiseSZX, + + tokenHandlerContainer: coapSync.NewMap[uint64, HandlerFunc](), + midHandlerContainer: coapSync.NewMap[int32, *midElement](), + processReceivedMessage: cfg.ProcessReceivedMessage, + errors: cfg.Errors, + msgIDMutex: NewMutexMap(), + responseMsgCache: cache.NewCache[string, []byte](), + inactivityMonitor: inactivityMonitor, + messagePool: cfg.MessagePool, + numOutstandingInteraction: semaphore.NewWeighted(math.MaxInt64), + } + cc.msgID.Store(uint32(cfg.GetMID() - 0xffff/2)) + cc.blockWise = createBlockWise(&cc) + limitParallelRequests := limitparallelrequests.New(cfg.LimitClientParallelRequests, cfg.LimitClientEndpointParallelRequests, cc.do, cc.doObserve) + cc.observationHandler = observation.NewHandler(&cc, cfg.Handler, limitParallelRequests.Do) + cc.Client = client.New(&cc, cc.observationHandler, cfg.GetToken, limitParallelRequests) + if cc.processReceivedMessage == nil { + cc.processReceivedMessage = processReceivedMessage + } + cc.receivedMessageReader = client.NewReceivedMessageReader(&cc, cfg.ReceivedMessageQueueSize) + return &cc +} + +func processReceivedMessage(req *pool.Message, cc *Conn, handler config.HandlerFunc[*Conn]) { + cc.ProcessReceivedMessageWithHandler(req, handler) +} + +func (cc *Conn) ProcessReceivedMessage(req *pool.Message) { + cc.processReceivedMessage(req, cc, cc.handleReq) +} + +func (cc *Conn) Session() Session { + return cc.session +} + +func (cc *Conn) GetMessageID() int32 { + // To prevent collisions during reconnections, it is important to always increment the global counter. + // For example, if a connection (cc) is established and later closed due to inactivity, a new cc may + // be created shortly after. However, if the new cc is initialized with the same message ID as the + // previous one, the receiver may mistakenly treat the incoming message as a duplicate and discard it. + // Hence, by incrementing the global counter, we can ensure unique message IDs and avoid such issues. + message.GetMID() + return int32(uint16(cc.msgID.Inc())) +} + +// Close closes connection without waiting for the end of the Run function. 
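The transmission setters above can be tuned per connection. A minimal sketch, using only APIs from this diff and assuming a CoAP server at localhost:5683:

```go
package main

import (
	"log"
	"time"

	"github.com/plgd-dev/go-coap/v3/udp"
)

func main() {
	cc, err := udp.Dial("localhost:5683")
	if err != nil {
		log.Fatalf("dial: %v", err)
	}
	defer cc.Close()

	// Tighten the RFC 7252 transmission parameters for a lossy link:
	// allow two outstanding interactions, retransmit sooner, retry less.
	t := cc.Transmission()
	t.SetTransmissionNStart(2)
	t.SetTransmissionAcknowledgeTimeout(time.Second)
	t.SetTransmissionMaxRetransmit(2)
}
```

Per the NStart comment, raising the value only affects requests queued afterwards; requests already blocked on the semaphore are released as weight becomes available.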
+func (cc *Conn) Close() error {
+	err := cc.session.Close()
+	if errors.Is(err, net.ErrClosed) {
+		return nil
+	}
+	return err
+}
+
+func (cc *Conn) doInternal(req *pool.Message) (*pool.Message, error) {
+	token := req.Token()
+	if token == nil {
+		return nil, fmt.Errorf("invalid token")
+	}
+
+	respChan := make(chan *pool.Message, 1)
+	if _, loaded := cc.tokenHandlerContainer.LoadOrStore(token.Hash(), func(w *responsewriter.ResponseWriter[*Conn], r *pool.Message) {
+		r.Hijack()
+		select {
+		case respChan <- r:
+		default:
+		}
+	}); loaded {
+		return nil, fmt.Errorf("cannot add token(%v) handler: %w", token, coapErrors.ErrKeyAlreadyExists)
+	}
+	defer func() {
+		_, _ = cc.tokenHandlerContainer.LoadAndDelete(token.Hash())
+	}()
+	err := cc.writeMessage(req)
+	if err != nil {
+		return nil, fmt.Errorf(errFmtWriteRequest, err)
+	}
+	cc.receivedMessageReader.TryToReplaceLoop()
+	select {
+	case <-req.Context().Done():
+		return nil, req.Context().Err()
+	case <-cc.Context().Done():
+		return nil, fmt.Errorf("connection was closed: %w", cc.session.Context().Err())
+	case resp := <-respChan:
+		return resp, nil
+	}
+}
+
+// Do sends a CoAP message and returns a CoAP response.
+//
+// An error is returned on failure to speak CoAP (such as a network
+// connectivity problem); a response status code does not cause an error.
+//
+// The caller is responsible for releasing the request and the response.
+func (cc *Conn) do(req *pool.Message) (*pool.Message, error) {
+	if cc.blockWise == nil {
+		return cc.doInternal(req)
+	}
+	resp, err := cc.blockWise.Do(req, cc.blockwiseSZX, cc.session.MaxMessageSize(), func(bwReq *pool.Message) (*pool.Message, error) {
+		if bwReq.Options().HasOption(message.Block1) || bwReq.Options().HasOption(message.Block2) {
+			bwReq.SetMessageID(cc.GetMessageID())
+		}
+		return cc.doInternal(bwReq)
+	})
+	if err != nil {
+		return nil, err
+	}
+	return resp, nil
+}
+
+// DoObserve subscribes to every change associated with the request.
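A request/response sketch against this API. `SetupGet` and `message.GetToken` appear elsewhere in this diff; the exported `Do` comes from the embedded `client.Client`, and the path is an assumption:

```go
package main

import (
	"context"
	"log"
	"time"

	"github.com/plgd-dev/go-coap/v3/message"
	"github.com/plgd-dev/go-coap/v3/udp"
)

func main() {
	cc, err := udp.Dial("localhost:5683")
	if err != nil {
		log.Fatalf("dial: %v", err)
	}
	defer cc.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// Build the request from the connection's message pool; the caller
	// releases both the request and the response.
	token, err := message.GetToken()
	if err != nil {
		log.Fatalf("token: %v", err)
	}
	req := cc.AcquireMessage(ctx)
	defer cc.ReleaseMessage(req)
	if err := req.SetupGet("/sensors/temp", token); err != nil {
		log.Fatalf("setup: %v", err)
	}

	resp, err := cc.Do(req)
	if err != nil {
		log.Fatalf("do: %v", err)
	}
	defer cc.ReleaseMessage(resp)
	log.Printf("response code: %v", resp.Code())
}
```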
+func (cc *Conn) doObserve(req *pool.Message, observeFunc func(req *pool.Message)) (client.Observation, error) { + return cc.observationHandler.NewObservation(req, observeFunc) +} + +func (cc *Conn) releaseOutstandingInteraction() { + cc.numOutstandingInteraction.Release(1) +} + +func (cc *Conn) acquireOutstandingInteraction(ctx context.Context) error { + nStart := cc.Transmission().nStart.Load() + if nStart == 0 { + return fmt.Errorf("invalid NStart value %v", nStart) + } + n := math.MaxInt64 - int64(cc.Transmission().nStart.Load()) + 1 + err := cc.numOutstandingInteraction.Acquire(ctx, n) + if err != nil { + return err + } + cc.numOutstandingInteraction.Release(n - 1) + return nil +} + +func (cc *Conn) waitForAcknowledge(req *pool.Message, waitForResponseChan chan struct{}) error { + cc.receivedMessageReader.TryToReplaceLoop() + select { + case <-waitForResponseChan: + return nil + case <-req.Context().Done(): + return req.Context().Err() + case <-cc.Context().Done(): + return fmt.Errorf("connection was closed: %w", cc.Context().Err()) + } +} + +func (cc *Conn) prepareWriteMessage(req *pool.Message, handler HandlerFunc) (func(), error) { + var closeFns fn.FuncList + + // Only confirmable messages ever match an message ID + switch req.Type() { + case message.Confirmable: + msg := cc.AcquireMessage(req.Context()) + if err := req.Clone(msg); err != nil { + cc.ReleaseMessage(msg) + return nil, fmt.Errorf("cannot clone message: %w", err) + } + if req.Code() >= codes.GET && req.Code() <= codes.DELETE { + if err := cc.acquireOutstandingInteraction(req.Context()); err != nil { + return nil, err + } + closeFns = append(closeFns, func() { + cc.releaseOutstandingInteraction() + }) + } + deadline, _ := req.Context().Deadline() + if _, loaded := cc.midHandlerContainer.LoadOrStore(req.MessageID(), &midElement{ + handler: handler, + start: time.Now(), + deadline: deadline, + private: struct { + sync.Mutex + msg *pool.Message + }{msg: msg}, + }); loaded { + closeFns.Execute() + return nil, fmt.Errorf("cannot insert mid(%v) handler: %w", req.MessageID(), coapErrors.ErrKeyAlreadyExists) + } + closeFns = append(closeFns, func() { + _, _ = cc.midHandlerContainer.LoadAndDelete(req.MessageID()) + }) + case message.NonConfirmable: + /* TODO need to acquireOutstandingInteraction + if req.Code() >= codes.GET && req.Code() <= codes.DELETE { + } + */ + } + return closeFns.ToFunction(), nil +} + +func (cc *Conn) writeMessageAsync(req *pool.Message) error { + req.UpsertType(message.Confirmable) + req.UpsertMessageID(cc.GetMessageID()) + closeFn, err := cc.prepareWriteMessage(req, func(w *responsewriter.ResponseWriter[*Conn], r *pool.Message) { + // do nothing + }) + if err != nil { + return err + } + defer closeFn() + if err := cc.session.WriteMessage(req); err != nil { + return fmt.Errorf(errFmtWriteRequest, err) + } + return nil +} + +func (cc *Conn) writeMessage(req *pool.Message) error { + req.UpsertType(message.Confirmable) + req.UpsertMessageID(cc.GetMessageID()) + if req.Type() != message.Confirmable { + return cc.writeMessageAsync(req) + } + respChan := make(chan struct{}) + closeFn, err := cc.prepareWriteMessage(req, func(w *responsewriter.ResponseWriter[*Conn], r *pool.Message) { + close(respChan) + }) + if err != nil { + return err + } + defer closeFn() + if err := cc.session.WriteMessage(req); err != nil { + return fmt.Errorf(errFmtWriteRequest, err) + } + if err := cc.waitForAcknowledge(req, respChan); err != nil { + return fmt.Errorf(errFmtWriteRequest, err) + } + return nil +} + +// WriteMessage sends 
an coap message. +func (cc *Conn) WriteMessage(req *pool.Message) error { + if cc.blockWise == nil { + return cc.writeMessage(req) + } + return cc.blockWise.WriteMessage(req, cc.blockwiseSZX, cc.session.MaxMessageSize(), func(bwReq *pool.Message) error { + if bwReq.Options().HasOption(message.Block1) || bwReq.Options().HasOption(message.Block2) { + bwReq.SetMessageID(cc.GetMessageID()) + } + return cc.writeMessage(bwReq) + }) +} + +// Context returns the client's context. +// +// If connections was closed context is cancelled. +func (cc *Conn) Context() context.Context { + return cc.session.Context() +} + +// AsyncPing sends ping and receivedPong will be called when pong arrives. It returns cancellation of ping operation. +func (cc *Conn) AsyncPing(receivedPong func()) (func(), error) { + req := cc.AcquireMessage(cc.Context()) + req.SetType(message.Confirmable) + req.SetCode(codes.Empty) + mid := cc.GetMessageID() + req.SetMessageID(mid) + if _, loaded := cc.midHandlerContainer.LoadOrStore(mid, &midElement{ + handler: func(w *responsewriter.ResponseWriter[*Conn], r *pool.Message) { + if r.Type() == message.Reset || r.Type() == message.Acknowledgement { + receivedPong() + } + }, + start: time.Now(), + deadline: time.Time{}, // no deadline + private: struct { + sync.Mutex + msg *pool.Message + }{msg: req}, + }); loaded { + return nil, fmt.Errorf("cannot insert mid(%v) handler: %w", mid, coapErrors.ErrKeyAlreadyExists) + } + removeMidHandler := func() { + if elem, ok := cc.midHandlerContainer.LoadAndDelete(mid); ok { + elem.ReleaseMessage(cc) + } + } + if err := cc.session.WriteMessage(req); err != nil { + removeMidHandler() + return nil, fmt.Errorf(errFmtWriteRequest, err) + } + return removeMidHandler, nil +} + +// Run reads and process requests from a connection, until the connection is closed. +func (cc *Conn) Run() error { + return cc.session.Run(cc) +} + +// AddOnClose calls function on close connection event. +func (cc *Conn) AddOnClose(f EventFunc) { + cc.session.AddOnClose(f) +} + +func (cc *Conn) RemoteAddr() net.Addr { + return cc.session.RemoteAddr() +} + +func (cc *Conn) LocalAddr() net.Addr { + return cc.session.LocalAddr() +} + +func (cc *Conn) sendPong(w *responsewriter.ResponseWriter[*Conn], r *pool.Message) { + if err := w.SetResponse(codes.Empty, message.TextPlain, nil); err != nil { + cc.errors(fmt.Errorf("cannot send pong response: %w", err)) + } + if r.Type() == message.Confirmable { + w.Message().SetType(message.Acknowledgement) + w.Message().SetMessageID(r.MessageID()) + } else { + if w.Message().Type() != message.Reset { + w.Message().SetType(message.NonConfirmable) + } + w.Message().SetMessageID(cc.GetMessageID()) + } +} + +func (cc *Conn) handle(w *responsewriter.ResponseWriter[*Conn], m *pool.Message) { + if m.IsSeparateMessage() { + // msg was processed by token handler - just drop it. + return + } + if cc.blockWise != nil { + cc.blockWise.Handle(w, m, cc.blockwiseSZX, cc.session.MaxMessageSize(), func(rw *responsewriter.ResponseWriter[*Conn], rm *pool.Message) { + if h, ok := cc.tokenHandlerContainer.LoadAndDelete(rm.Token().Hash()); ok { + h(rw, rm) + return + } + cc.observationHandler.Handle(rw, rm) + }) + return + } + if h, ok := cc.tokenHandlerContainer.LoadAndDelete(m.Token().Hash()); ok { + h(w, m) + return + } + cc.observationHandler.Handle(w, m) +} + +// Sequence acquires sequence number. 
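AsyncPing, shown above, is also the hook the keepalive monitor uses. A small sketch of calling it directly, assuming a reachable peer:

```go
package main

import (
	"log"
	"time"

	"github.com/plgd-dev/go-coap/v3/udp"
)

func main() {
	cc, err := udp.Dial("localhost:5683")
	if err != nil {
		log.Fatalf("dial: %v", err)
	}
	defer cc.Close()

	// AsyncPing registers a mid handler and sends an Empty Confirmable
	// message; the callback fires on the matching ACK or RST (pong).
	cancel, err := cc.AsyncPing(func() {
		log.Println("pong: peer is alive")
	})
	if err != nil {
		log.Fatalf("ping: %v", err)
	}
	// Release the mid handler if no pong arrives in time.
	time.Sleep(2 * time.Second)
	cancel()
}
```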
+func (cc *Conn) Sequence() uint64 {
+	return cc.sequence.Add(1)
+}
+
+func (cc *Conn) responseMsgCacheID(msgID int32) string {
+	return fmt.Sprintf("resp-%v-%d", cc.RemoteAddr(), msgID)
+}
+
+func (cc *Conn) addResponseToCache(resp *pool.Message) error {
+	marshaledResp, err := resp.MarshalWithEncoder(coder.DefaultCoder)
+	if err != nil {
+		return err
+	}
+	cacheMsg := make([]byte, len(marshaledResp))
+	copy(cacheMsg, marshaledResp)
+	cc.responseMsgCache.LoadOrStore(cc.responseMsgCacheID(resp.MessageID()), cache.NewElement(cacheMsg, time.Now().Add(ExchangeLifetime), nil))
+	return nil
+}
+
+func (cc *Conn) getResponseFromCache(mid int32, resp *pool.Message) (bool, error) {
+	cachedResp := cc.responseMsgCache.Load(cc.responseMsgCacheID(mid))
+	if cachedResp == nil {
+		return false, nil
+	}
+	if rawMsg := cachedResp.Data(); len(rawMsg) > 0 {
+		_, err := resp.UnmarshalWithDecoder(coder.DefaultCoder, rawMsg)
+		if err != nil {
+			return false, err
+		}
+		return true, nil
+	}
+	return false, nil
+}
+
+// checkMyMessageID compares the client msgID against the peer's messageID and,
+// when they are closer than 0xffff/4, advances the client msgID by 0xffff/2.
+// When msgIDs collide, the response cache may serve a message that does not
+// belong to the request.
+func (cc *Conn) checkMyMessageID(req *pool.Message) {
+	if req.Type() == message.Confirmable {
+		for {
+			oldID := cc.msgID.Load()
+			if uint16(req.MessageID())-uint16(cc.msgID.Load()) >= 0xffff/4 {
+				return
+			}
+			newID := oldID + 0xffff/2
+			if cc.msgID.CompareAndSwap(oldID, newID) {
+				break
+			}
+		}
+	}
+}
+
+func (cc *Conn) checkResponseCache(req *pool.Message, w *responsewriter.ResponseWriter[*Conn]) (bool, error) {
+	if req.Type() == message.Confirmable || req.Type() == message.NonConfirmable {
+		if ok, err := cc.getResponseFromCache(req.MessageID(), w.Message()); ok {
+			w.Message().SetMessageID(req.MessageID())
+			w.Message().SetType(message.NonConfirmable)
+			if req.Type() == message.Confirmable {
+				// req could be changed from a NonConfirmable to a Confirmable message.
+ w.Message().SetType(message.Acknowledgement) + } + return true, nil + } else if err != nil { + return false, fmt.Errorf("cannot unmarshal response from cache: %w", err) + } + } + return false, nil +} + +func isPongOrResetResponse(w *responsewriter.ResponseWriter[*Conn]) bool { + return w.Message().IsModified() && (w.Message().Type() == message.Reset || w.Message().Code() == codes.Empty) +} + +func sendJustAcknowledgeMessage(reqType message.Type, w *responsewriter.ResponseWriter[*Conn]) bool { + return reqType == message.Confirmable && !w.Message().IsModified() +} + +func (cc *Conn) processResponse(reqType message.Type, reqMessageID int32, w *responsewriter.ResponseWriter[*Conn]) error { + switch { + case isPongOrResetResponse(w): + if reqType == message.Confirmable { + w.Message().SetType(message.Acknowledgement) + w.Message().SetMessageID(reqMessageID) + } else { + if w.Message().Type() != message.Reset { + w.Message().SetType(message.NonConfirmable) + } + w.Message().SetMessageID(cc.GetMessageID()) + } + return nil + case sendJustAcknowledgeMessage(reqType, w): + // send message to separate(confirm received) message, if response is not modified + w.Message().SetCode(codes.Empty) + w.Message().SetType(message.Acknowledgement) + w.Message().SetMessageID(reqMessageID) + w.Message().SetToken(nil) + err := cc.addResponseToCache(w.Message()) + if err != nil { + return fmt.Errorf("cannot cache response: %w", err) + } + return nil + case !w.Message().IsModified(): + // don't send response + return nil + } + + // send piggybacked response + w.Message().SetType(message.Confirmable) + w.Message().SetMessageID(cc.GetMessageID()) + if reqType == message.Confirmable { + w.Message().SetType(message.Acknowledgement) + w.Message().SetMessageID(reqMessageID) + } + if reqType == message.Confirmable || reqType == message.NonConfirmable { + err := cc.addResponseToCache(w.Message()) + if err != nil { + return fmt.Errorf("cannot cache response: %w", err) + } + } + return nil +} + +func (cc *Conn) handleReq(w *responsewriter.ResponseWriter[*Conn], req *pool.Message) { + defer cc.inactivityMonitor.Notify() + reqMid := req.MessageID() + + // The same message ID can not be handled concurrently + // for deduplication to work + l := cc.msgIDMutex.Lock(reqMid) + defer l.Unlock() + + if ok, err := cc.checkResponseCache(req, w); err != nil { + cc.closeConnection() + cc.errors(fmt.Errorf(errFmtWriteResponse, err)) + return + } else if ok { + return + } + + w.Message().SetModified(false) + reqType := req.Type() + reqMessageID := req.MessageID() + cc.handle(w, req) + + err := cc.processResponse(reqType, reqMessageID, w) + if err != nil { + cc.closeConnection() + cc.errors(fmt.Errorf(errFmtWriteResponse, err)) + } +} + +func (cc *Conn) closeConnection() { + if errC := cc.Close(); errC != nil { + cc.errors(fmt.Errorf("cannot close connection: %w", errC)) + } +} + +func (cc *Conn) ProcessReceivedMessageWithHandler(req *pool.Message, handler config.HandlerFunc[*Conn]) { + defer func() { + if !req.IsHijacked() { + cc.ReleaseMessage(req) + } + }() + resp := cc.AcquireMessage(cc.Context()) + resp.SetToken(req.Token()) + w := responsewriter.New(resp, cc, req.Options()...) 
+	defer func() {
+		cc.ReleaseMessage(w.Message())
+	}()
+	handler(w, req)
+	select {
+	case <-cc.Context().Done():
+		return
+	default:
+	}
+	if !w.Message().IsModified() {
+		// nothing to send
+		return
+	}
+	errW := cc.writeMessageAsync(w.Message())
+	if errW != nil {
+		cc.closeConnection()
+		cc.errors(fmt.Errorf(errFmtWriteResponse, errW))
+	}
+}
+
+func (cc *Conn) handlePong(w *responsewriter.ResponseWriter[*Conn], r *pool.Message) {
+	cc.sendPong(w, r)
+}
+
+func (cc *Conn) handleSpecialMessages(r *pool.Message) bool {
+	// ping request
+	if r.Code() == codes.Empty && r.Type() == message.Confirmable && len(r.Token()) == 0 && len(r.Options()) == 0 && r.Body() == nil {
+		cc.ProcessReceivedMessageWithHandler(r, cc.handlePong)
+		return true
+	}
+	// if waits for concrete message handler
+	if elem, ok := cc.midHandlerContainer.LoadAndDelete(r.MessageID()); ok {
+		elem.ReleaseMessage(cc)
+		resp := cc.AcquireMessage(cc.Context())
+		resp.SetToken(r.Token())
+		w := responsewriter.New(resp, cc, r.Options()...)
+		defer func() {
+			cc.ReleaseMessage(w.Message())
+		}()
+		elem.handler(w, r)
+		// we just confirmed that the message was processed for cc.writeMessage;
+		// the body of the message still needs to be processed by the loopOverReceivedMessageQueue goroutine
+		return false
+	}
+	// separate message
+	if r.IsSeparateMessage() {
+		// msg was processed by token handler - just drop it.
+		return true
+	}
+	return false
+}
+
+func (cc *Conn) Process(datagram []byte) error {
+	if uint32(len(datagram)) > cc.session.MaxMessageSize() {
+		return fmt.Errorf("max message size(%v) was exceeded %v", cc.session.MaxMessageSize(), len(datagram))
+	}
+	req := cc.AcquireMessage(cc.Context())
+	_, err := req.UnmarshalWithDecoder(coder.DefaultCoder, datagram)
+	if err != nil {
+		cc.ReleaseMessage(req)
+		return err
+	}
+	req.SetSequence(cc.Sequence())
+	cc.checkMyMessageID(req)
+	cc.inactivityMonitor.Notify()
+	if cc.handleSpecialMessages(req) {
+		return nil
+	}
+	select {
+	case cc.receivedMessageReader.C() <- req:
+	case <-cc.Context().Done():
+	}
+	return nil
+}
+
+// SetContextValue stores the value associated with key to context of connection.
+func (cc *Conn) SetContextValue(key interface{}, val interface{}) {
+	cc.session.SetContextValue(key, val)
+}
+
+// Done signals that the connection is no longer processed.
+func (cc *Conn) Done() <-chan struct{} {
+	return cc.session.Done()
+}
+
+func (cc *Conn) checkMidHandlerContainer(now time.Time, maxRetransmit int32, acknowledgeTimeout time.Duration, key int32, value *midElement) {
+	if value.IsExpired(now, maxRetransmit) {
+		cc.midHandlerContainer.Delete(key)
+		value.ReleaseMessage(cc)
+		cc.errors(fmt.Errorf(errFmtWriteRequest, context.DeadlineExceeded))
+		return
+	}
+	if !value.Retransmit(now, acknowledgeTimeout) {
+		return
+	}
+	msg, ok, err := value.GetMessage(cc)
+	if err != nil {
+		cc.midHandlerContainer.Delete(key)
+		value.ReleaseMessage(cc)
+		cc.errors(fmt.Errorf(errFmtWriteRequest, err))
+		return
+	}
+	if ok {
+		defer cc.ReleaseMessage(msg)
+		err := cc.session.WriteMessage(msg)
+		if err != nil {
+			cc.errors(fmt.Errorf(errFmtWriteRequest, err))
+		}
+	}
+}
+
+// CheckExpirations checks and removes expired items from caches.
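The interplay of GetMessageID and the checkMyMessageID guard above is easiest to see with plain uint16 arithmetic. This toy program (not the vendored API) mirrors the wraparound-aware distance check and the 0xffff/2 jump:

```go
package main

import "fmt"

// tooClose mirrors the guard in checkMyMessageID: with uint16 wraparound,
// peer-mine < 0xffff/4 means the peer's MIDs trail ours too closely for
// the response cache to stay unambiguous.
func tooClose(peer, mine uint16) bool {
	return peer-mine < 0xffff/4
}

func main() {
	mine := uint16(100)
	for _, peer := range []uint16{150, 20000, 65000} {
		if tooClose(peer, mine) {
			// jump half the MID space away, as checkMyMessageID does
			fmt.Printf("peer=%d: bump %d -> %d\n", peer, mine, mine+0xffff/2)
		} else {
			fmt.Printf("peer=%d: keep %d\n", peer, mine)
		}
	}
}
```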
+func (cc *Conn) CheckExpirations(now time.Time) { + cc.inactivityMonitor.CheckInactivity(now, cc) + cc.responseMsgCache.CheckExpirations(now) + if cc.blockWise != nil { + cc.blockWise.CheckExpirations(now) + } + maxRetransmit := cc.transmission.maxRetransmit.Load() + acknowledgeTimeout := cc.transmission.acknowledgeTimeout.Load() + x := struct { + now time.Time + maxRetransmit int32 + acknowledgeTimeout time.Duration + cc *Conn + }{ + now: now, + maxRetransmit: maxRetransmit, + acknowledgeTimeout: acknowledgeTimeout, + cc: cc, + } + cc.midHandlerContainer.Range(func(key int32, value *midElement) bool { + x.cc.checkMidHandlerContainer(x.now, x.maxRetransmit, x.acknowledgeTimeout, key, value) + return true + }) +} + +func (cc *Conn) AcquireMessage(ctx context.Context) *pool.Message { + return cc.messagePool.AcquireMessage(ctx) +} + +func (cc *Conn) ReleaseMessage(m *pool.Message) { + cc.messagePool.ReleaseMessage(m) +} + +// WriteMulticastMessage sends multicast to the remote multicast address. +// By default it is sent over all network interfaces and all compatible source IP addresses with hop limit 1. +// Via opts you can specify the network interface, source IP address, and hop limit. +func (cc *Conn) WriteMulticastMessage(req *pool.Message, address *net.UDPAddr, options ...coapNet.MulticastOption) error { + if req.Type() == message.Confirmable { + return fmt.Errorf("multicast messages cannot be confirmable") + } + req.UpsertMessageID(cc.GetMessageID()) + + err := cc.session.WriteMulticastMessage(req, address, options...) + if err != nil { + return fmt.Errorf(errFmtWriteRequest, err) + } + return nil +} + +func (cc *Conn) InactivityMonitor() InactivityMonitor { + return cc.inactivityMonitor +} + +// NetConn returns the underlying connection that is wrapped by cc. The Conn returned is shared by all invocations of NetConn, so do not modify it. +func (cc *Conn) NetConn() net.Conn { + return cc.session.NetConn() +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/udp/client/mutexmap.go b/vendor/github.com/plgd-dev/go-coap/v3/udp/client/mutexmap.go new file mode 100644 index 0000000000..6665aceab3 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/udp/client/mutexmap.go @@ -0,0 +1,71 @@ +package client + +import ( + "fmt" + "sync" +) + +// MutexMap wraps a map of mutexes. Each key locks separately. +type MutexMap struct { + ma map[interface{}]*mutexMapEntry // entry map + ml sync.Mutex // lock for entry map +} + +type mutexMapEntry struct { + key interface{} // key in ma + m *MutexMap // point back to MutexMap, so we can synchronize removing this mutexMapEntry when cnt==0 + el sync.Mutex // entry-specific lock + cnt uint16 // reference count +} + +// Unlocker provides an Unlock method to release the lock. +type Unlocker interface { + Unlock() +} + +// NewMutexMap returns an initialized MutexMap. +func NewMutexMap() *MutexMap { + return &MutexMap{ma: make(map[interface{}]*mutexMapEntry)} +} + +// Lock acquires a lock corresponding to this key. +// This method will never return nil and Unlock() must be called +// to release the lock when done. +func (m *MutexMap) Lock(key interface{}) Unlocker { + // read or create entry for this key atomically + m.ml.Lock() + e, ok := m.ma[key] + if !ok { + e = &mutexMapEntry{m: m, key: key} + m.ma[key] = e + } + e.cnt++ // ref count + m.ml.Unlock() + + // acquire lock, will block here until e.cnt==1 + e.el.Lock() + + return e +} + +// Unlock releases the lock for this entry. 
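The MutexMap above (with Unlock just below) is what serializes handling of a single message ID for deduplication. A minimal sketch of the same API in isolation:

```go
package main

import (
	"fmt"
	"sync"

	"github.com/plgd-dev/go-coap/v3/udp/client"
)

func main() {
	mm := client.NewMutexMap()
	var wg sync.WaitGroup
	for i := 0; i < 4; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			// All goroutines contend on the same key (message ID 42),
			// so their critical sections serialize; a different key
			// would not block against this one.
			l := mm.Lock(int32(42))
			defer l.Unlock()
			fmt.Println("goroutine", id, "holds MID 42")
		}(i)
	}
	wg.Wait()
}
```

The reference count in mutexMapEntry is what lets the map entry be removed once the last waiter for a key unlocks.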
+func (entry *mutexMapEntry) Unlock() { + m := entry.m + + // decrement and if needed remove entry atomically + m.ml.Lock() + e, ok := m.ma[entry.key] + if !ok { // entry must exist + m.ml.Unlock() + panic(fmt.Errorf("unlock requested for key=%v but no entry found", entry.key)) + } + e.cnt-- // ref count + if e.cnt < 1 { // if it hits zero then we own it and remove from map + delete(m.ma, entry.key) + } + m.ml.Unlock() + + // now that map stuff is handled, we unlock and let + // anything else waiting on this key through + e.el.Unlock() +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/udp/coder/coder.go b/vendor/github.com/plgd-dev/go-coap/v3/udp/coder/coder.go new file mode 100644 index 0000000000..4b1f0d35de --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/udp/coder/coder.go @@ -0,0 +1,140 @@ +package coder + +import ( + "encoding/binary" + "errors" + "fmt" + + "github.com/plgd-dev/go-coap/v3/message" + "github.com/plgd-dev/go-coap/v3/message/codes" +) + +var DefaultCoder = new(Coder) + +type Coder struct{} + +func (c *Coder) Size(m message.Message) (int, error) { + if len(m.Token) > message.MaxTokenSize { + return -1, message.ErrInvalidTokenLen + } + size := 4 + len(m.Token) + payloadLen := len(m.Payload) + optionsLen, err := m.Options.Marshal(nil) + if !errors.Is(err, message.ErrTooSmall) { + return -1, err + } + if payloadLen > 0 { + // for separator 0xff + payloadLen++ + } + size += payloadLen + optionsLen + return size, nil +} + +func (c *Coder) Encode(m message.Message, buf []byte) (int, error) { + /* + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + |Ver| T | TKL | Code | Message ID | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Token (if any, TKL bytes) ... + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Options (if any) ... + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + |1 1 1 1 1 1 1 1| Payload (if any) ... 
+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + */ + if !message.ValidateMID(m.MessageID) { + return -1, fmt.Errorf("invalid MessageID(%v)", m.MessageID) + } + if !message.ValidateType(m.Type) { + return -1, fmt.Errorf("invalid Type(%v)", m.Type) + } + size, err := c.Size(m) + if err != nil { + return -1, err + } + if len(buf) < size { + return size, message.ErrTooSmall + } + + tmpbuf := []byte{0, 0} + binary.BigEndian.PutUint16(tmpbuf, uint16(m.MessageID)) + + buf[0] = (1 << 6) | byte(m.Type)<<4 | byte(0xf&len(m.Token)) + buf[1] = byte(m.Code) + buf[2] = tmpbuf[0] + buf[3] = tmpbuf[1] + buf = buf[4:] + + if len(m.Token) > message.MaxTokenSize { + return -1, message.ErrInvalidTokenLen + } + copy(buf, m.Token) + buf = buf[len(m.Token):] + + optionsLen, err := m.Options.Marshal(buf) + switch { + case err == nil: + case errors.Is(err, message.ErrTooSmall): + return size, err + default: + return -1, err + } + buf = buf[optionsLen:] + + if len(m.Payload) > 0 { + buf[0] = 0xff + buf = buf[1:] + } + copy(buf, m.Payload) + return size, nil +} + +func (c *Coder) Decode(data []byte, m *message.Message) (int, error) { + size := len(data) + if size < 4 { + return -1, ErrMessageTruncated + } + + if data[0]>>6 != 1 { + return -1, ErrMessageInvalidVersion + } + + typ := message.Type((data[0] >> 4) & 0x3) + tokenLen := int(data[0] & 0xf) + if tokenLen > 8 { + return -1, message.ErrInvalidTokenLen + } + + code := codes.Code(data[1]) + messageID := binary.BigEndian.Uint16(data[2:4]) + data = data[4:] + if len(data) < tokenLen { + return -1, ErrMessageTruncated + } + token := data[:tokenLen] + if len(token) == 0 { + token = nil + } + data = data[tokenLen:] + + optionDefs := message.CoapOptionDefs + proc, err := m.Options.Unmarshal(data, optionDefs) + if err != nil { + return -1, err + } + data = data[proc:] + if len(data) == 0 { + data = nil + } + + m.Payload = data + m.Code = code + m.Token = token + m.Type = typ + m.MessageID = int32(messageID) + + return size, nil +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/udp/coder/error.go b/vendor/github.com/plgd-dev/go-coap/v3/udp/coder/error.go new file mode 100644 index 0000000000..ac7ae99aad --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/udp/coder/error.go @@ -0,0 +1,8 @@ +package coder + +import "errors" + +var ( + ErrMessageTruncated = errors.New("message is truncated") + ErrMessageInvalidVersion = errors.New("message has invalid version") +) diff --git a/vendor/github.com/plgd-dev/go-coap/v3/udp/server.go b/vendor/github.com/plgd-dev/go-coap/v3/udp/server.go new file mode 100644 index 0000000000..767bfffae1 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/udp/server.go @@ -0,0 +1,7 @@ +package udp + +import "github.com/plgd-dev/go-coap/v3/udp/server" + +func NewServer(opt ...server.Option) *server.Server { + return server.New(opt...) 
+}
diff --git a/vendor/github.com/plgd-dev/go-coap/v3/udp/server/config.go b/vendor/github.com/plgd-dev/go-coap/v3/udp/server/config.go
new file mode 100644
index 0000000000..7cfb498e3d
--- /dev/null
+++ b/vendor/github.com/plgd-dev/go-coap/v3/udp/server/config.go
@@ -0,0 +1,64 @@
+package server
+
+import (
+	"fmt"
+	"time"
+
+	"github.com/plgd-dev/go-coap/v3/message"
+	"github.com/plgd-dev/go-coap/v3/message/codes"
+	"github.com/plgd-dev/go-coap/v3/message/pool"
+	"github.com/plgd-dev/go-coap/v3/net/monitor/inactivity"
+	"github.com/plgd-dev/go-coap/v3/net/responsewriter"
+	"github.com/plgd-dev/go-coap/v3/options/config"
+	udpClient "github.com/plgd-dev/go-coap/v3/udp/client"
+)
+
+// The HandlerFunc type is an adapter to allow the use of
+// ordinary functions as CoAP handlers.
+type HandlerFunc = func(*responsewriter.ResponseWriter[*udpClient.Conn], *pool.Message)
+
+type ErrorFunc = func(error)
+
+// OnNewConnFunc is the callback for new connections.
+type OnNewConnFunc = func(cc *udpClient.Conn)
+
+type GetMIDFunc = func() int32
+
+var DefaultConfig = func() Config {
+	opts := Config{
+		Common: config.NewCommon[*udpClient.Conn](),
+		CreateInactivityMonitor: func() udpClient.InactivityMonitor {
+			timeout := time.Second * 16
+			onInactive := func(cc *udpClient.Conn) {
+				_ = cc.Close()
+			}
+			return inactivity.New(timeout, onInactive)
+		},
+		OnNewConn: func(cc *udpClient.Conn) {
+			// do nothing by default
+		},
+		TransmissionNStart:             1,
+		TransmissionAcknowledgeTimeout: time.Second * 2,
+		TransmissionMaxRetransmit:      4,
+		GetMID:                         message.GetMID,
+		MTU:                            udpClient.DefaultMTU,
+	}
+	opts.Handler = func(w *responsewriter.ResponseWriter[*udpClient.Conn], r *pool.Message) {
+		if err := w.SetResponse(codes.NotFound, message.TextPlain, nil); err != nil {
+			opts.Errors(fmt.Errorf("udp server: cannot set response: %w", err))
+		}
+	}
+	return opts
+}()
+
+type Config struct {
+	config.Common[*udpClient.Conn]
+	CreateInactivityMonitor        udpClient.CreateInactivityMonitorFunc
+	GetMID                         GetMIDFunc
+	Handler                        HandlerFunc
+	OnNewConn                      OnNewConnFunc
+	TransmissionNStart             uint32
+	TransmissionAcknowledgeTimeout time.Duration
+	TransmissionMaxRetransmit      uint32
+	MTU                            uint16
+}
diff --git a/vendor/github.com/plgd-dev/go-coap/v3/udp/server/discover.go b/vendor/github.com/plgd-dev/go-coap/v3/udp/server/discover.go
new file mode 100644
index 0000000000..aa41af4ce8
--- /dev/null
+++ b/vendor/github.com/plgd-dev/go-coap/v3/udp/server/discover.go
@@ -0,0 +1,90 @@
+package server
+
+import (
+	"context"
+	"fmt"
+	"net"
+
+	"github.com/plgd-dev/go-coap/v3/message"
+	"github.com/plgd-dev/go-coap/v3/message/pool"
+	coapNet "github.com/plgd-dev/go-coap/v3/net"
+	"github.com/plgd-dev/go-coap/v3/net/responsewriter"
+	"github.com/plgd-dev/go-coap/v3/pkg/errors"
+	"github.com/plgd-dev/go-coap/v3/udp/client"
+	"github.com/plgd-dev/go-coap/v3/udp/coder"
+)
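A usage sketch for the discovery API defined below. `coapNet.NewListenUDP` and the all-CoAP-nodes multicast address 224.0.1.187:5683 are assumptions, not shown in this diff:

```go
package main

import (
	"context"
	"log"
	"time"

	"github.com/plgd-dev/go-coap/v3/message/pool"
	coapNet "github.com/plgd-dev/go-coap/v3/net"
	"github.com/plgd-dev/go-coap/v3/udp"
	"github.com/plgd-dev/go-coap/v3/udp/client"
)

func main() {
	conn, err := coapNet.NewListenUDP("udp4", ":5683")
	if err != nil {
		log.Fatalf("listen: %v", err)
	}
	defer conn.Close()

	s := udp.NewServer()
	defer s.Stop()
	go func() {
		if err := s.Serve(conn); err != nil {
			log.Printf("serve: %v", err)
		}
	}()

	// Discover blocks until the context expires, collecting every
	// response via the callback (peers may answer from any address).
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()
	err = s.Discover(ctx, "224.0.1.187:5683", "/.well-known/core",
		func(cc *client.Conn, resp *pool.Message) {
			log.Printf("%v -> %v", cc.RemoteAddr(), resp.Code())
		})
	if err != nil {
		log.Printf("discover: %v", err)
	}
}
```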
+// Discover sends a GET request to a multicast or unicast address and waits for responses until the context times out or the server shuts down.
+// For unicast this differs from Dial: Dial is connection-oriented, so the peer must respond from the same address the request was sent to, whereas Discover allows the peer to respond from a different address.
+// By default the request is sent over all network interfaces and all compatible source IP addresses, with a hop limit of 1.
+// Via opts you can specify the network interface, source IP address, and hop limit.
+func (s *Server) Discover(ctx context.Context, address, path string, receiverFunc func(cc *client.Conn, resp *pool.Message), opts ...coapNet.MulticastOption) error {
+	token, err := s.cfg.GetToken()
+	if err != nil {
+		return fmt.Errorf("cannot get token: %w", err)
+	}
+	req := s.cfg.MessagePool.AcquireMessage(ctx)
+	defer s.cfg.MessagePool.ReleaseMessage(req)
+	err = req.SetupGet(path, token)
+	if err != nil {
+		return fmt.Errorf("cannot create discover request: %w", err)
+	}
+	req.SetMessageID(s.cfg.GetMID())
+	req.SetType(message.NonConfirmable)
+	return s.DiscoveryRequest(req, address, receiverFunc, opts...)
+}
+
+// DiscoveryRequest sends a request to a multicast/unicast address and waits for responses until the request times out or the server shuts down.
+// For unicast this differs from Dial: Dial is connection-oriented, so the peer must respond from the same address the request was sent to, whereas DiscoveryRequest allows the peer to respond from a different address.
+// By default the request is sent over all network interfaces and all compatible source IP addresses, with a hop limit of 1.
+// Via opts you can specify the network interface, source IP address, and hop limit.
+func (s *Server) DiscoveryRequest(req *pool.Message, address string, receiverFunc func(cc *client.Conn, resp *pool.Message), opts ...coapNet.MulticastOption) error {
+	token := req.Token()
+	if len(token) == 0 {
+		return fmt.Errorf("invalid token")
+	}
+	c := s.conn()
+	if c == nil {
+		return fmt.Errorf("server doesn't serve connection")
+	}
+	addr, err := net.ResolveUDPAddr(c.Network(), address)
+	if err != nil {
+		return fmt.Errorf("cannot resolve address: %w", err)
+	}
+
+	data, err := req.MarshalWithEncoder(coder.DefaultCoder)
+	if err != nil {
+		return fmt.Errorf("cannot marshal req: %w", err)
+	}
+	s.multicastRequests.Store(token.Hash(), req)
+	defer s.multicastRequests.Delete(token.Hash())
+	if _, loaded := s.multicastHandler.LoadOrStore(token.Hash(), func(w *responsewriter.ResponseWriter[*client.Conn], r *pool.Message) {
+		receiverFunc(w.Conn(), r)
+	}); loaded {
+		return errors.ErrKeyAlreadyExists
+	}
+	defer func() {
+		_, _ = s.multicastHandler.LoadAndDelete(token.Hash())
+	}()
+
+	if addr.IP.IsMulticast() {
+		err = c.WriteMulticast(req.Context(), addr, data, opts...)
+ if err != nil { + return err + } + } else { + err = c.WriteWithContext(req.Context(), addr, data) + if err != nil { + return err + } + } + + select { + case <-req.Context().Done(): + return nil + case <-s.ctx.Done(): + return fmt.Errorf("server was closed: %w", s.ctx.Err()) + } +} diff --git a/vendor/github.com/plgd-dev/go-coap/v3/udp/server/server.go b/vendor/github.com/plgd-dev/go-coap/v3/udp/server/server.go new file mode 100644 index 0000000000..8bceeab5a6 --- /dev/null +++ b/vendor/github.com/plgd-dev/go-coap/v3/udp/server/server.go @@ -0,0 +1,383 @@ +package server + +import ( + "context" + "fmt" + "net" + "sync" + "time" + + "github.com/plgd-dev/go-coap/v3/message" + "github.com/plgd-dev/go-coap/v3/message/pool" + coapNet "github.com/plgd-dev/go-coap/v3/net" + "github.com/plgd-dev/go-coap/v3/net/blockwise" + "github.com/plgd-dev/go-coap/v3/net/monitor/inactivity" + "github.com/plgd-dev/go-coap/v3/net/responsewriter" + "github.com/plgd-dev/go-coap/v3/pkg/cache" + coapSync "github.com/plgd-dev/go-coap/v3/pkg/sync" + "github.com/plgd-dev/go-coap/v3/udp/client" +) + +type Server struct { + doneCtx context.Context + ctx context.Context + multicastRequests *client.RequestsMap + multicastHandler *coapSync.Map[uint64, HandlerFunc] + serverStartedChan chan struct{} + doneCancel context.CancelFunc + cancel context.CancelFunc + responseMsgCache *cache.Cache[string, []byte] + + connsMutex sync.Mutex + conns map[string]*client.Conn + + listenMutex sync.Mutex + listen *coapNet.UDPConn + + cfg *Config +} + +// A Option sets options such as credentials, codec and keepalive parameters, etc. +type Option interface { + UDPServerApply(cfg *Config) +} + +func New(opt ...Option) *Server { + cfg := DefaultConfig + for _, o := range opt { + o.UDPServerApply(&cfg) + } + + if cfg.Errors == nil { + cfg.Errors = func(error) { + // default no-op + } + } + + if cfg.GetMID == nil { + cfg.GetMID = message.GetMID + } + + if cfg.GetToken == nil { + cfg.GetToken = message.GetToken + } + + if cfg.CreateInactivityMonitor == nil { + cfg.CreateInactivityMonitor = func() client.InactivityMonitor { + return inactivity.NewNilMonitor[*client.Conn]() + } + } + if cfg.MessagePool == nil { + cfg.MessagePool = pool.New(0, 0) + } + + ctx, cancel := context.WithCancel(cfg.Ctx) + serverStartedChan := make(chan struct{}) + + doneCtx, doneCancel := context.WithCancel(context.Background()) + errorsFunc := cfg.Errors + cfg.Errors = func(err error) { + if coapNet.IsCancelOrCloseError(err) { + // this error was produced by cancellation context or closing connection. 
+			return
+		}
+		errorsFunc(fmt.Errorf("udp: %w", err))
+	}
+	return &Server{
+		ctx:               ctx,
+		cancel:            cancel,
+		multicastHandler:  coapSync.NewMap[uint64, HandlerFunc](),
+		multicastRequests: coapSync.NewMap[uint64, *pool.Message](),
+		serverStartedChan: serverStartedChan,
+		doneCtx:           doneCtx,
+		doneCancel:        doneCancel,
+		responseMsgCache:  cache.NewCache[string, []byte](),
+		conns:             make(map[string]*client.Conn),
+
+		cfg: &cfg,
+	}
+}
+
+func (s *Server) checkAndSetListener(l *coapNet.UDPConn) error {
+	s.listenMutex.Lock()
+	defer s.listenMutex.Unlock()
+	if s.listen != nil {
+		return fmt.Errorf("server already serve: %v", s.listen.LocalAddr().String())
+	}
+	s.listen = l
+	close(s.serverStartedChan)
+	return nil
+}
+
+func (s *Server) closeConnection(cc *client.Conn) {
+	if err := cc.Close(); err != nil {
+		s.cfg.Errors(fmt.Errorf("cannot close connection: %w", err))
+	}
+}
+
+func (s *Server) Serve(l *coapNet.UDPConn) error {
+	if s.cfg.BlockwiseSZX > blockwise.SZX1024 {
+		return fmt.Errorf("invalid blockwiseSZX")
+	}
+
+	err := s.checkAndSetListener(l)
+	if err != nil {
+		return err
+	}
+
+	defer func() {
+		s.closeSessions()
+		s.doneCancel()
+		s.listenMutex.Lock()
+		defer s.listenMutex.Unlock()
+		s.listen = nil
+		s.serverStartedChan = make(chan struct{}, 1)
+	}()
+
+	m := make([]byte, s.cfg.MaxMessageSize)
+	var wg sync.WaitGroup
+
+	s.cfg.PeriodicRunner(func(now time.Time) bool {
+		s.handleInactivityMonitors(now)
+		s.responseMsgCache.CheckExpirations(now)
+		return s.ctx.Err() == nil
+	})
+
+	for {
+		buf := m
+		n, raddr, err := l.ReadWithContext(s.ctx, buf)
+		if err != nil {
+			wg.Wait()
+
+			select {
+			case <-s.ctx.Done():
+				return nil
+			default:
+				if coapNet.IsCancelOrCloseError(err) {
+					return nil
+				}
+				return err
+			}
+		}
+		buf = buf[:n]
+		cc, err := s.getConn(l, raddr, true)
+		if err != nil {
+			s.cfg.Errors(fmt.Errorf("%v: cannot get client connection: %w", raddr, err))
+			continue
+		}
+		err = cc.Process(buf)
+		if err != nil {
+			s.closeConnection(cc)
+			s.cfg.Errors(fmt.Errorf("%v: cannot process packet: %w", cc.RemoteAddr(), err))
+		}
+	}
+}
+
+func (s *Server) getListener() *coapNet.UDPConn {
+	s.listenMutex.Lock()
+	defer s.listenMutex.Unlock()
+	return s.listen
+}
+
+// Stop stops the server without waiting for the Serve function to finish.
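+//
+// A minimal shutdown sketch (illustrative only; srv and conn are hypothetical,
+// with conn obtained from e.g. coapNet.NewListenUDP):
+//
+//	go func() {
+//		_ = srv.Serve(conn)
+//	}()
+//	// ... later: returns immediately, without waiting for Serve to exit.
+//	srv.Stop()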
+func (s *Server) Stop() { + s.cancel() + l := s.getListener() + if l != nil { + if errC := l.Close(); errC != nil { + s.cfg.Errors(fmt.Errorf("cannot close listener: %w", errC)) + } + } + s.closeSessions() +} + +func (s *Server) closeSessions() { + s.connsMutex.Lock() + conns := s.conns + s.conns = make(map[string]*client.Conn) + s.connsMutex.Unlock() + for _, cc := range conns { + s.closeConnection(cc) + if closeFn := getClose(cc); closeFn != nil { + closeFn() + } + } +} + +func (s *Server) conn() *coapNet.UDPConn { + s.listenMutex.Lock() + serverStartedChan := s.serverStartedChan + s.listenMutex.Unlock() + select { + case <-serverStartedChan: + case <-s.ctx.Done(): + } + s.listenMutex.Lock() + defer s.listenMutex.Unlock() + return s.listen +} + +const closeKey = "gocoapCloseConnection" + +func (s *Server) getConns() []*client.Conn { + s.connsMutex.Lock() + defer s.connsMutex.Unlock() + conns := make([]*client.Conn, 0, 32) + for _, c := range s.conns { + conns = append(conns, c) + } + return conns +} + +func (s *Server) handleInactivityMonitors(now time.Time) { + for _, cc := range s.getConns() { + select { + case <-cc.Context().Done(): + if closeFn := getClose(cc); closeFn != nil { + closeFn() + } + continue + default: + cc.CheckExpirations(now) + } + } +} + +func getClose(cc *client.Conn) func() { + v := cc.Context().Value(closeKey) + if v == nil { + return nil + } + closeFn, ok := v.(func()) + if !ok { + panic(fmt.Errorf("invalid type(%T) of context value for key %s", v, closeKey)) + } + return closeFn +} + +func (s *Server) getOrCreateConn(udpConn *coapNet.UDPConn, raddr *net.UDPAddr) (cc *client.Conn, created bool) { + s.connsMutex.Lock() + defer s.connsMutex.Unlock() + key := raddr.String() + cc = s.conns[key] + + if cc != nil { + return cc, false + } + + createBlockWise := func(cc *client.Conn) *blockwise.BlockWise[*client.Conn] { + return nil + } + if s.cfg.BlockwiseEnable { + createBlockWise = func(cc *client.Conn) *blockwise.BlockWise[*client.Conn] { + v := cc + return blockwise.New( + v, + s.cfg.BlockwiseTransferTimeout, + s.cfg.Errors, + func(token message.Token) (*pool.Message, bool) { + msg, ok := v.GetObservationRequest(token) + if ok { + return msg, ok + } + return s.multicastRequests.LoadWithFunc(token.Hash(), func(m *pool.Message) *pool.Message { + msg := v.AcquireMessage(m.Context()) + msg.ResetOptionsTo(m.Options()) + msg.SetCode(m.Code()) + msg.SetToken(m.Token()) + msg.SetMessageID(m.MessageID()) + return msg + }) + }) + } + } + session := NewSession( + s.ctx, + s.doneCtx, + udpConn, + raddr, + s.cfg.MaxMessageSize, + s.cfg.MTU, + false, + ) + monitor := s.cfg.CreateInactivityMonitor() + cfg := client.DefaultConfig + cfg.TransmissionNStart = s.cfg.TransmissionNStart + cfg.TransmissionAcknowledgeTimeout = s.cfg.TransmissionAcknowledgeTimeout + cfg.TransmissionMaxRetransmit = s.cfg.TransmissionMaxRetransmit + cfg.Handler = func(w *responsewriter.ResponseWriter[*client.Conn], r *pool.Message) { + h, ok := s.multicastHandler.Load(r.Token().Hash()) + if ok { + h(w, r) + return + } + s.cfg.Handler(w, r) + } + cfg.BlockwiseSZX = s.cfg.BlockwiseSZX + cfg.Errors = s.cfg.Errors + cfg.GetMID = s.cfg.GetMID + cfg.GetToken = s.cfg.GetToken + cfg.MessagePool = s.cfg.MessagePool + cfg.ProcessReceivedMessage = s.cfg.ProcessReceivedMessage + cfg.ReceivedMessageQueueSize = s.cfg.ReceivedMessageQueueSize + + cc = client.NewConn( + session, + createBlockWise, + monitor, + &cfg, + ) + cc.SetContextValue(closeKey, func() { + if err := session.Close(); err != nil { + 
s.cfg.Errors(fmt.Errorf("cannot close session: %w", err))
+		}
+		session.shutdown()
+	})
+	cc.AddOnClose(func() {
+		s.connsMutex.Lock()
+		defer s.connsMutex.Unlock()
+		if cc == s.conns[key] {
+			delete(s.conns, key)
+		}
+	})
+	s.conns[key] = cc
+	return cc, true
+}
+
+func (s *Server) getConn(l *coapNet.UDPConn, raddr *net.UDPAddr, firstTime bool) (*client.Conn, error) {
+	cc, created := s.getOrCreateConn(l, raddr)
+	if created {
+		if s.cfg.OnNewConn != nil {
+			s.cfg.OnNewConn(cc)
+		}
+	} else {
+		// Check whether the client will have expired by now + 10ms; if so, close it.
+		// 10ms is the expected maximum time taken by cc.CheckExpirations and cc.InactivityMonitor().Notify().
+		cc.CheckExpirations(time.Now().Add(10 * time.Millisecond))
+		if cc.Context().Err() == nil {
+			// the client is not closed, so extend its expiration time
+			cc.InactivityMonitor().Notify()
+		}
+	}
+
+	if cc.Context().Err() != nil {
+		// the connection is closed, so we need to create a new one
+		if closeFn := getClose(cc); closeFn != nil {
+			closeFn()
+		}
+		if firstTime {
+			return s.getConn(l, raddr, false)
+		}
+		return nil, fmt.Errorf("connection is closed")
+	}
+	return cc, nil
+}
+
+func (s *Server) NewConn(addr *net.UDPAddr) (*client.Conn, error) {
+	l := s.getListener()
+	if l == nil {
+		// the server is not started or has been stopped
+		return nil, fmt.Errorf("server is not running")
+	}
+	return s.getConn(l, addr, true)
+}
diff --git a/vendor/github.com/plgd-dev/go-coap/v3/udp/server/session.go b/vendor/github.com/plgd-dev/go-coap/v3/udp/server/session.go
new file mode 100644
index 0000000000..870ff53c10
--- /dev/null
+++ b/vendor/github.com/plgd-dev/go-coap/v3/udp/server/session.go
@@ -0,0 +1,165 @@
+package server
+
+import (
+	"context"
+	"fmt"
+	"net"
+	"sync"
+	"sync/atomic"
+
+	"github.com/plgd-dev/go-coap/v3/message/pool"
+	coapNet "github.com/plgd-dev/go-coap/v3/net"
+	"github.com/plgd-dev/go-coap/v3/udp/client"
+	"github.com/plgd-dev/go-coap/v3/udp/coder"
+)
+
+type EventFunc = func()
+
+type Session struct {
+	onClose []EventFunc
+
+	ctx atomic.Value // TODO: change to atomic.Pointer[context.Context] for go1.19
+
+	doneCtx    context.Context
+	connection *coapNet.UDPConn
+	doneCancel context.CancelFunc
+
+	cancel context.CancelFunc
+	raddr  *net.UDPAddr
+
+	mutex          sync.Mutex
+	maxMessageSize uint32
+	mtu            uint16
+
+	closeSocket bool
+}
+
+func NewSession(
+	ctx context.Context,
+	doneCtx context.Context,
+	connection *coapNet.UDPConn,
+	raddr *net.UDPAddr,
+	maxMessageSize uint32,
+	mtu uint16,
+	closeSocket bool,
+) *Session {
+	ctx, cancel := context.WithCancel(ctx)
+
+	doneCtx, doneCancel := context.WithCancel(doneCtx)
+	s := &Session{
+		cancel:         cancel,
+		connection:     connection,
+		raddr:          raddr,
+		maxMessageSize: maxMessageSize,
+		mtu:            mtu,
+		closeSocket:    closeSocket,
+		doneCtx:        doneCtx,
+		doneCancel:     doneCancel,
+	}
+	s.ctx.Store(&ctx)
+	return s
+}
+
+// SetContextValue stores the value associated with key in the connection's context.
+func (s *Session) SetContextValue(key interface{}, val interface{}) {
+	ctx := context.WithValue(s.Context(), key, val)
+	s.ctx.Store(&ctx)
+}
+
+// Done signals that the connection is no longer being processed.
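+//
+// A sketch of waiting for teardown (sess and ctx are illustrative):
+//
+//	select {
+//	case <-sess.Done():
+//		// the session's OnClose handlers have completed
+//	case <-ctx.Done():
+//	}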
+func (s *Session) Done() <-chan struct{} { + return s.doneCtx.Done() +} + +func (s *Session) AddOnClose(f EventFunc) { + s.mutex.Lock() + defer s.mutex.Unlock() + s.onClose = append(s.onClose, f) +} + +func (s *Session) popOnClose() []EventFunc { + s.mutex.Lock() + defer s.mutex.Unlock() + tmp := s.onClose + s.onClose = nil + return tmp +} + +func (s *Session) shutdown() { + defer s.doneCancel() + for _, f := range s.popOnClose() { + f() + } +} + +func (s *Session) Close() error { + s.cancel() + if s.closeSocket { + return s.connection.Close() + } + return nil +} + +func (s *Session) Context() context.Context { + return *s.ctx.Load().(*context.Context) //nolint:forcetypeassert +} + +func (s *Session) WriteMessage(req *pool.Message) error { + data, err := req.MarshalWithEncoder(coder.DefaultCoder) + if err != nil { + return fmt.Errorf("cannot marshal: %w", err) + } + return s.connection.WriteWithContext(req.Context(), s.raddr, data) +} + +// WriteMulticastMessage sends multicast to the remote multicast address. +// By default it is sent over all network interfaces and all compatible source IP addresses with hop limit 1. +// Via opts you can specify the network interface, source IP address, and hop limit. +func (s *Session) WriteMulticastMessage(req *pool.Message, address *net.UDPAddr, opts ...coapNet.MulticastOption) error { + data, err := req.MarshalWithEncoder(coder.DefaultCoder) + if err != nil { + return fmt.Errorf("cannot marshal: %w", err) + } + + return s.connection.WriteMulticast(req.Context(), address, data, opts...) +} + +func (s *Session) Run(cc *client.Conn) (err error) { + defer func() { + err1 := s.Close() + if err == nil { + err = err1 + } + s.shutdown() + }() + m := make([]byte, s.mtu) + for { + buf := m + n, _, err := s.connection.ReadWithContext(s.Context(), buf) + if err != nil { + return err + } + buf = buf[:n] + err = cc.Process(buf) + if err != nil { + return err + } + } +} + +func (s *Session) MaxMessageSize() uint32 { + return s.maxMessageSize +} + +func (s *Session) RemoteAddr() net.Addr { + return s.raddr +} + +func (s *Session) LocalAddr() net.Addr { + return s.connection.LocalAddr() +} + +// NetConn returns the underlying connection that is wrapped by s. The Conn returned is shared by all invocations of NetConn, so do not modify it. +func (s *Session) NetConn() net.Conn { + return s.connection.NetConn() +} diff --git a/vendor/github.com/prometheus/client_golang/prometheus/counter.go b/vendor/github.com/prometheus/client_golang/prometheus/counter.go new file mode 100644 index 0000000000..62de4dc59a --- /dev/null +++ b/vendor/github.com/prometheus/client_golang/prometheus/counter.go @@ -0,0 +1,348 @@ +// Copyright 2014 The Prometheus Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package prometheus + +import ( + "errors" + "math" + "sync/atomic" + "time" + + dto "github.com/prometheus/client_model/go" +) + +// Counter is a Metric that represents a single numerical value that only ever +// goes up. 
That implies that it cannot be used to count items whose number can
+// also go down, e.g. the number of currently running goroutines. Those
+// "counters" are represented by Gauges.
+//
+// A Counter is typically used to count requests served, tasks completed, errors
+// occurred, etc.
+//
+// To create Counter instances, use NewCounter.
+type Counter interface {
+	Metric
+	Collector
+
+	// Inc increments the counter by 1. Use Add to increment it by arbitrary
+	// non-negative values.
+	Inc()
+	// Add adds the given value to the counter. It panics if the value is <
+	// 0.
+	Add(float64)
+}
+
+// ExemplarAdder is implemented by Counters that offer the option of adding a
+// value to the Counter together with an exemplar. Its AddWithExemplar method
+// works like the Add method of the Counter interface but also replaces the
+// currently saved exemplar (if any) with a new one, created from the provided
+// value, the current time as timestamp, and the provided labels. Empty Labels
+// will lead to a valid (label-less) exemplar. But if Labels is nil, the current
+// exemplar is left in place. AddWithExemplar panics if the value is < 0, if any
+// of the provided labels are invalid, or if the provided labels contain more
+// than 128 runes in total.
+type ExemplarAdder interface {
+	AddWithExemplar(value float64, exemplar Labels)
+}
+
+// CounterOpts is an alias for Opts. See there for doc comments.
+type CounterOpts Opts
+
+// CounterVecOpts bundles the options to create a CounterVec metric.
+// It is mandatory to set CounterOpts, see there for mandatory fields. VariableLabels
+// is optional and can safely be left to its default value.
+type CounterVecOpts struct {
+	CounterOpts
+
+	// VariableLabels are used to partition the metric vector by the given set
+	// of labels. Each label value will be constrained with the optional Constraint
+	// function, if provided.
+	VariableLabels ConstrainableLabels
+}
+
+// NewCounter creates a new Counter based on the provided CounterOpts.
+//
+// The returned implementation also implements ExemplarAdder. It is safe to
+// perform the corresponding type assertion.
+//
+// The returned implementation tracks the counter value in two separate
+// variables, a float64 and a uint64. The latter is used to track calls of the
+// Inc method and calls of the Add method with a value that can be represented
+// as a uint64. This allows atomic increments of the counter with optimal
+// performance. (It is common to have an Inc call in very hot execution paths.)
+// Both internal tracking values are added up in the Write method. This has to
+// be taken into account when it comes to precision and overflow behavior.
+func NewCounter(opts CounterOpts) Counter {
+	desc := NewDesc(
+		BuildFQName(opts.Namespace, opts.Subsystem, opts.Name),
+		opts.Help,
+		nil,
+		opts.ConstLabels,
+	)
+	result := &counter{desc: desc, labelPairs: desc.constLabelPairs, now: time.Now}
+	result.init(result) // Init self-collection.
+	return result
+}
+
+type counter struct {
+	// valBits contains the bits of the represented float64 value, while
+	// valInt stores values that are exact integers. Both have to go first
+	// in the struct to guarantee alignment for atomic operations.
+	// http://golang.org/pkg/sync/atomic/#pkg-note-BUG
+	valBits uint64
+	valInt  uint64
+
+	selfCollector
+	desc *Desc
+
+	labelPairs []*dto.LabelPair
+	exemplar   atomic.Value // Containing nil or a *dto.Exemplar.
+
+	now func() time.Time // To mock out time.Now() for testing.
+} + +func (c *counter) Desc() *Desc { + return c.desc +} + +func (c *counter) Add(v float64) { + if v < 0 { + panic(errors.New("counter cannot decrease in value")) + } + + ival := uint64(v) + if float64(ival) == v { + atomic.AddUint64(&c.valInt, ival) + return + } + + for { + oldBits := atomic.LoadUint64(&c.valBits) + newBits := math.Float64bits(math.Float64frombits(oldBits) + v) + if atomic.CompareAndSwapUint64(&c.valBits, oldBits, newBits) { + return + } + } +} + +func (c *counter) AddWithExemplar(v float64, e Labels) { + c.Add(v) + c.updateExemplar(v, e) +} + +func (c *counter) Inc() { + atomic.AddUint64(&c.valInt, 1) +} + +func (c *counter) get() float64 { + fval := math.Float64frombits(atomic.LoadUint64(&c.valBits)) + ival := atomic.LoadUint64(&c.valInt) + return fval + float64(ival) +} + +func (c *counter) Write(out *dto.Metric) error { + // Read the Exemplar first and the value second. This is to avoid a race condition + // where users see an exemplar for a not-yet-existing observation. + var exemplar *dto.Exemplar + if e := c.exemplar.Load(); e != nil { + exemplar = e.(*dto.Exemplar) + } + val := c.get() + + return populateMetric(CounterValue, val, c.labelPairs, exemplar, out) +} + +func (c *counter) updateExemplar(v float64, l Labels) { + if l == nil { + return + } + e, err := newExemplar(v, c.now(), l) + if err != nil { + panic(err) + } + c.exemplar.Store(e) +} + +// CounterVec is a Collector that bundles a set of Counters that all share the +// same Desc, but have different values for their variable labels. This is used +// if you want to count the same thing partitioned by various dimensions +// (e.g. number of HTTP requests, partitioned by response code and +// method). Create instances with NewCounterVec. +type CounterVec struct { + *MetricVec +} + +// NewCounterVec creates a new CounterVec based on the provided CounterOpts and +// partitioned by the given label names. +func NewCounterVec(opts CounterOpts, labelNames []string) *CounterVec { + return V2.NewCounterVec(CounterVecOpts{ + CounterOpts: opts, + VariableLabels: UnconstrainedLabels(labelNames), + }) +} + +// NewCounterVec creates a new CounterVec based on the provided CounterVecOpts. +func (v2) NewCounterVec(opts CounterVecOpts) *CounterVec { + desc := V2.NewDesc( + BuildFQName(opts.Namespace, opts.Subsystem, opts.Name), + opts.Help, + opts.VariableLabels, + opts.ConstLabels, + ) + return &CounterVec{ + MetricVec: NewMetricVec(desc, func(lvs ...string) Metric { + if len(lvs) != len(desc.variableLabels) { + panic(makeInconsistentCardinalityError(desc.fqName, desc.variableLabels.labelNames(), lvs)) + } + result := &counter{desc: desc, labelPairs: MakeLabelPairs(desc, lvs), now: time.Now} + result.init(result) // Init self-collection. + return result + }), + } +} + +// GetMetricWithLabelValues returns the Counter for the given slice of label +// values (same order as the variable labels in Desc). If that combination of +// label values is accessed for the first time, a new Counter is created. +// +// It is possible to call this method without using the returned Counter to only +// create the new Counter but leave it at its starting value 0. See also the +// SummaryVec example. +// +// Keeping the Counter for later use is possible (and should be considered if +// performance is critical), but keep in mind that Reset, DeleteLabelValues and +// Delete can be used to delete the Counter from the CounterVec. 
In that case, +// the Counter will still exist, but it will not be exported anymore, even if a +// Counter with the same label values is created later. +// +// An error is returned if the number of label values is not the same as the +// number of variable labels in Desc (minus any curried labels). +// +// Note that for more than one label value, this method is prone to mistakes +// caused by an incorrect order of arguments. Consider GetMetricWith(Labels) as +// an alternative to avoid that type of mistake. For higher label numbers, the +// latter has a much more readable (albeit more verbose) syntax, but it comes +// with a performance overhead (for creating and processing the Labels map). +// See also the GaugeVec example. +func (v *CounterVec) GetMetricWithLabelValues(lvs ...string) (Counter, error) { + metric, err := v.MetricVec.GetMetricWithLabelValues(lvs...) + if metric != nil { + return metric.(Counter), err + } + return nil, err +} + +// GetMetricWith returns the Counter for the given Labels map (the label names +// must match those of the variable labels in Desc). If that label map is +// accessed for the first time, a new Counter is created. Implications of +// creating a Counter without using it and keeping the Counter for later use are +// the same as for GetMetricWithLabelValues. +// +// An error is returned if the number and names of the Labels are inconsistent +// with those of the variable labels in Desc (minus any curried labels). +// +// This method is used for the same purpose as +// GetMetricWithLabelValues(...string). See there for pros and cons of the two +// methods. +func (v *CounterVec) GetMetricWith(labels Labels) (Counter, error) { + metric, err := v.MetricVec.GetMetricWith(labels) + if metric != nil { + return metric.(Counter), err + } + return nil, err +} + +// WithLabelValues works as GetMetricWithLabelValues, but panics where +// GetMetricWithLabelValues would have returned an error. Not returning an +// error allows shortcuts like +// +// myVec.WithLabelValues("404", "GET").Add(42) +func (v *CounterVec) WithLabelValues(lvs ...string) Counter { + c, err := v.GetMetricWithLabelValues(lvs...) + if err != nil { + panic(err) + } + return c +} + +// With works as GetMetricWith, but panics where GetMetricWithLabels would have +// returned an error. Not returning an error allows shortcuts like +// +// myVec.With(prometheus.Labels{"code": "404", "method": "GET"}).Add(42) +func (v *CounterVec) With(labels Labels) Counter { + c, err := v.GetMetricWith(labels) + if err != nil { + panic(err) + } + return c +} + +// CurryWith returns a vector curried with the provided labels, i.e. the +// returned vector has those labels pre-set for all labeled operations performed +// on it. The cardinality of the curried vector is reduced accordingly. The +// order of the remaining labels stays the same (just with the curried labels +// taken out of the sequence – which is relevant for the +// (GetMetric)WithLabelValues methods). It is possible to curry a curried +// vector, but only with labels not yet used for currying before. +// +// The metrics contained in the CounterVec are shared between the curried and +// uncurried vectors. They are just accessed differently. Curried and uncurried +// vectors behave identically in terms of collection. Only one must be +// registered with a given registry (usually the uncurried version). The Reset +// method deletes all metrics, even if called on a curried vector. 
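+//
+// A small sketch (requestCounter is a hypothetical CounterVec with the
+// variable labels "code" and "method"):
+//
+//	getOnly := requestCounter.MustCurryWith(prometheus.Labels{"method": "GET"})
+//	getOnly.WithLabelValues("404").Inc() // only the remaining "code" label is supplied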
+func (v *CounterVec) CurryWith(labels Labels) (*CounterVec, error) { + vec, err := v.MetricVec.CurryWith(labels) + if vec != nil { + return &CounterVec{vec}, err + } + return nil, err +} + +// MustCurryWith works as CurryWith but panics where CurryWith would have +// returned an error. +func (v *CounterVec) MustCurryWith(labels Labels) *CounterVec { + vec, err := v.CurryWith(labels) + if err != nil { + panic(err) + } + return vec +} + +// CounterFunc is a Counter whose value is determined at collect time by calling a +// provided function. +// +// To create CounterFunc instances, use NewCounterFunc. +type CounterFunc interface { + Metric + Collector +} + +// NewCounterFunc creates a new CounterFunc based on the provided +// CounterOpts. The value reported is determined by calling the given function +// from within the Write method. Take into account that metric collection may +// happen concurrently. If that results in concurrent calls to Write, like in +// the case where a CounterFunc is directly registered with Prometheus, the +// provided function must be concurrency-safe. The function should also honor +// the contract for a Counter (values only go up, not down), but compliance will +// not be checked. +// +// Check out the ExampleGaugeFunc examples for the similar GaugeFunc. +func NewCounterFunc(opts CounterOpts, function func() float64) CounterFunc { + return newValueFunc(NewDesc( + BuildFQName(opts.Namespace, opts.Subsystem, opts.Name), + opts.Help, + nil, + opts.ConstLabels, + ), CounterValue, function) +} diff --git a/vendor/github.com/prometheus/client_golang/prometheus/desc.go b/vendor/github.com/prometheus/client_golang/prometheus/desc.go new file mode 100644 index 0000000000..deedc2dfbe --- /dev/null +++ b/vendor/github.com/prometheus/client_golang/prometheus/desc.go @@ -0,0 +1,199 @@ +// Copyright 2016 The Prometheus Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package prometheus + +import ( + "fmt" + "sort" + "strings" + + "github.com/cespare/xxhash/v2" + dto "github.com/prometheus/client_model/go" + "github.com/prometheus/common/model" + "google.golang.org/protobuf/proto" + + "github.com/prometheus/client_golang/prometheus/internal" +) + +// Desc is the descriptor used by every Prometheus Metric. It is essentially +// the immutable meta-data of a Metric. The normal Metric implementations +// included in this package manage their Desc under the hood. Users only have to +// deal with Desc if they use advanced features like the ExpvarCollector or +// custom Collectors and Metrics. +// +// Descriptors registered with the same registry have to fulfill certain +// consistency and uniqueness criteria if they share the same fully-qualified +// name: They must have the same help string and the same label names (aka label +// dimensions) in each, constLabels and variableLabels, but they must differ in +// the values of the constLabels. 
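+//
+// For example (an illustrative sketch), the following two descriptors are
+// consistent with each other and considered distinct, because they differ only
+// in the value of their const label:
+//
+//	prometheus.NewDesc("api_requests_total", "Total requests.", nil, prometheus.Labels{"shard": "a"})
+//	prometheus.NewDesc("api_requests_total", "Total requests.", nil, prometheus.Labels{"shard": "b"})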
+// +// Descriptors that share the same fully-qualified names and the same label +// values of their constLabels are considered equal. +// +// Use NewDesc to create new Desc instances. +type Desc struct { + // fqName has been built from Namespace, Subsystem, and Name. + fqName string + // help provides some helpful information about this metric. + help string + // constLabelPairs contains precalculated DTO label pairs based on + // the constant labels. + constLabelPairs []*dto.LabelPair + // variableLabels contains names of labels and normalization function for + // which the metric maintains variable values. + variableLabels ConstrainedLabels + // id is a hash of the values of the ConstLabels and fqName. This + // must be unique among all registered descriptors and can therefore be + // used as an identifier of the descriptor. + id uint64 + // dimHash is a hash of the label names (preset and variable) and the + // Help string. Each Desc with the same fqName must have the same + // dimHash. + dimHash uint64 + // err is an error that occurred during construction. It is reported on + // registration time. + err error +} + +// NewDesc allocates and initializes a new Desc. Errors are recorded in the Desc +// and will be reported on registration time. variableLabels and constLabels can +// be nil if no such labels should be set. fqName must not be empty. +// +// variableLabels only contain the label names. Their label values are variable +// and therefore not part of the Desc. (They are managed within the Metric.) +// +// For constLabels, the label values are constant. Therefore, they are fully +// specified in the Desc. See the Collector example for a usage pattern. +func NewDesc(fqName, help string, variableLabels []string, constLabels Labels) *Desc { + return V2.NewDesc(fqName, help, UnconstrainedLabels(variableLabels), constLabels) +} + +// NewDesc allocates and initializes a new Desc. Errors are recorded in the Desc +// and will be reported on registration time. variableLabels and constLabels can +// be nil if no such labels should be set. fqName must not be empty. +// +// variableLabels only contain the label names and normalization functions. Their +// label values are variable and therefore not part of the Desc. (They are managed +// within the Metric.) +// +// For constLabels, the label values are constant. Therefore, they are fully +// specified in the Desc. See the Collector example for a usage pattern. +func (v2) NewDesc(fqName, help string, variableLabels ConstrainableLabels, constLabels Labels) *Desc { + d := &Desc{ + fqName: fqName, + help: help, + variableLabels: variableLabels.constrainedLabels(), + } + if !model.IsValidMetricName(model.LabelValue(fqName)) { + d.err = fmt.Errorf("%q is not a valid metric name", fqName) + return d + } + // labelValues contains the label values of const labels (in order of + // their sorted label names) plus the fqName (at position 0). + labelValues := make([]string, 1, len(constLabels)+1) + labelValues[0] = fqName + labelNames := make([]string, 0, len(constLabels)+len(d.variableLabels)) + labelNameSet := map[string]struct{}{} + // First add only the const label names and sort them... + for labelName := range constLabels { + if !checkLabelName(labelName) { + d.err = fmt.Errorf("%q is not a valid label name for metric %q", labelName, fqName) + return d + } + labelNames = append(labelNames, labelName) + labelNameSet[labelName] = struct{}{} + } + sort.Strings(labelNames) + // ... so that we can now add const label values in the order of their names. 
+ for _, labelName := range labelNames { + labelValues = append(labelValues, constLabels[labelName]) + } + // Validate the const label values. They can't have a wrong cardinality, so + // use in len(labelValues) as expectedNumberOfValues. + if err := validateLabelValues(labelValues, len(labelValues)); err != nil { + d.err = err + return d + } + // Now add the variable label names, but prefix them with something that + // cannot be in a regular label name. That prevents matching the label + // dimension with a different mix between preset and variable labels. + for _, label := range d.variableLabels { + if !checkLabelName(label.Name) { + d.err = fmt.Errorf("%q is not a valid label name for metric %q", label.Name, fqName) + return d + } + labelNames = append(labelNames, "$"+label.Name) + labelNameSet[label.Name] = struct{}{} + } + if len(labelNames) != len(labelNameSet) { + d.err = fmt.Errorf("duplicate label names in constant and variable labels for metric %q", fqName) + return d + } + + xxh := xxhash.New() + for _, val := range labelValues { + xxh.WriteString(val) + xxh.Write(separatorByteSlice) + } + d.id = xxh.Sum64() + // Sort labelNames so that order doesn't matter for the hash. + sort.Strings(labelNames) + // Now hash together (in this order) the help string and the sorted + // label names. + xxh.Reset() + xxh.WriteString(help) + xxh.Write(separatorByteSlice) + for _, labelName := range labelNames { + xxh.WriteString(labelName) + xxh.Write(separatorByteSlice) + } + d.dimHash = xxh.Sum64() + + d.constLabelPairs = make([]*dto.LabelPair, 0, len(constLabels)) + for n, v := range constLabels { + d.constLabelPairs = append(d.constLabelPairs, &dto.LabelPair{ + Name: proto.String(n), + Value: proto.String(v), + }) + } + sort.Sort(internal.LabelPairSorter(d.constLabelPairs)) + return d +} + +// NewInvalidDesc returns an invalid descriptor, i.e. a descriptor with the +// provided error set. If a collector returning such a descriptor is registered, +// registration will fail with the provided error. NewInvalidDesc can be used by +// a Collector to signal inability to describe itself. +func NewInvalidDesc(err error) *Desc { + return &Desc{ + err: err, + } +} + +func (d *Desc) String() string { + lpStrings := make([]string, 0, len(d.constLabelPairs)) + for _, lp := range d.constLabelPairs { + lpStrings = append( + lpStrings, + fmt.Sprintf("%s=%q", lp.GetName(), lp.GetValue()), + ) + } + return fmt.Sprintf( + "Desc{fqName: %q, help: %q, constLabels: {%s}, variableLabels: %v}", + d.fqName, + d.help, + strings.Join(lpStrings, ","), + d.variableLabels, + ) +} diff --git a/vendor/github.com/prometheus/client_golang/prometheus/expvar_collector.go b/vendor/github.com/prometheus/client_golang/prometheus/expvar_collector.go new file mode 100644 index 0000000000..c41ab37f3b --- /dev/null +++ b/vendor/github.com/prometheus/client_golang/prometheus/expvar_collector.go @@ -0,0 +1,86 @@ +// Copyright 2014 The Prometheus Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package prometheus + +import ( + "encoding/json" + "expvar" +) + +type expvarCollector struct { + exports map[string]*Desc +} + +// NewExpvarCollector is the obsolete version of collectors.NewExpvarCollector. +// See there for documentation. +// +// Deprecated: Use collectors.NewExpvarCollector instead. +func NewExpvarCollector(exports map[string]*Desc) Collector { + return &expvarCollector{ + exports: exports, + } +} + +// Describe implements Collector. +func (e *expvarCollector) Describe(ch chan<- *Desc) { + for _, desc := range e.exports { + ch <- desc + } +} + +// Collect implements Collector. +func (e *expvarCollector) Collect(ch chan<- Metric) { + for name, desc := range e.exports { + var m Metric + expVar := expvar.Get(name) + if expVar == nil { + continue + } + var v interface{} + labels := make([]string, len(desc.variableLabels)) + if err := json.Unmarshal([]byte(expVar.String()), &v); err != nil { + ch <- NewInvalidMetric(desc, err) + continue + } + var processValue func(v interface{}, i int) + processValue = func(v interface{}, i int) { + if i >= len(labels) { + copiedLabels := append(make([]string, 0, len(labels)), labels...) + switch v := v.(type) { + case float64: + m = MustNewConstMetric(desc, UntypedValue, v, copiedLabels...) + case bool: + if v { + m = MustNewConstMetric(desc, UntypedValue, 1, copiedLabels...) + } else { + m = MustNewConstMetric(desc, UntypedValue, 0, copiedLabels...) + } + default: + return + } + ch <- m + return + } + vm, ok := v.(map[string]interface{}) + if !ok { + return + } + for lv, val := range vm { + labels[i] = lv + processValue(val, i+1) + } + } + processValue(v, 0) + } +} diff --git a/vendor/github.com/prometheus/client_golang/prometheus/gauge.go b/vendor/github.com/prometheus/client_golang/prometheus/gauge.go new file mode 100644 index 0000000000..f1ea6c76f7 --- /dev/null +++ b/vendor/github.com/prometheus/client_golang/prometheus/gauge.go @@ -0,0 +1,311 @@ +// Copyright 2014 The Prometheus Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package prometheus + +import ( + "math" + "sync/atomic" + "time" + + dto "github.com/prometheus/client_model/go" +) + +// Gauge is a Metric that represents a single numerical value that can +// arbitrarily go up and down. +// +// A Gauge is typically used for measured values like temperatures or current +// memory usage, but also "counts" that can go up and down, like the number of +// running goroutines. +// +// To create Gauge instances, use NewGauge. +type Gauge interface { + Metric + Collector + + // Set sets the Gauge to an arbitrary value. + Set(float64) + // Inc increments the Gauge by 1. Use Add to increment it by arbitrary + // values. + Inc() + // Dec decrements the Gauge by 1. Use Sub to decrement it by arbitrary + // values. + Dec() + // Add adds the given value to the Gauge. (The value can be negative, + // resulting in a decrease of the Gauge.) + Add(float64) + // Sub subtracts the given value from the Gauge. 
(The value can be
+	// negative, resulting in an increase of the Gauge.)
+	Sub(float64)
+
+	// SetToCurrentTime sets the Gauge to the current Unix time in seconds.
+	SetToCurrentTime()
+}
+
+// GaugeOpts is an alias for Opts. See there for doc comments.
+type GaugeOpts Opts
+
+// GaugeVecOpts bundles the options to create a GaugeVec metric.
+// It is mandatory to set GaugeOpts, see there for mandatory fields. VariableLabels
+// is optional and can safely be left to its default value.
+type GaugeVecOpts struct {
+	GaugeOpts
+
+	// VariableLabels are used to partition the metric vector by the given set
+	// of labels. Each label value will be constrained with the optional Constraint
+	// function, if provided.
+	VariableLabels ConstrainableLabels
+}
+
+// NewGauge creates a new Gauge based on the provided GaugeOpts.
+//
+// The returned implementation is optimized for a fast Set method. If you have a
+// choice for managing the value of a Gauge via Set vs. Inc/Dec/Add/Sub, pick
+// the former. For example, the Inc method of the returned Gauge is slower than
+// the Inc method of a Counter returned by NewCounter. This matches the typical
+// scenarios for Gauges and Counters, where the former tends to be Set-heavy and
+// the latter Inc-heavy.
+func NewGauge(opts GaugeOpts) Gauge {
+	desc := NewDesc(
+		BuildFQName(opts.Namespace, opts.Subsystem, opts.Name),
+		opts.Help,
+		nil,
+		opts.ConstLabels,
+	)
+	result := &gauge{desc: desc, labelPairs: desc.constLabelPairs}
+	result.init(result) // Init self-collection.
+	return result
+}
+
+type gauge struct {
+	// valBits contains the bits of the represented float64 value. It has
+	// to go first in the struct to guarantee alignment for atomic
+	// operations. http://golang.org/pkg/sync/atomic/#pkg-note-BUG
+	valBits uint64
+
+	selfCollector
+
+	desc       *Desc
+	labelPairs []*dto.LabelPair
+}
+
+func (g *gauge) Desc() *Desc {
+	return g.desc
+}
+
+func (g *gauge) Set(val float64) {
+	atomic.StoreUint64(&g.valBits, math.Float64bits(val))
+}
+
+func (g *gauge) SetToCurrentTime() {
+	g.Set(float64(time.Now().UnixNano()) / 1e9)
+}
+
+func (g *gauge) Inc() {
+	g.Add(1)
+}
+
+func (g *gauge) Dec() {
+	g.Add(-1)
+}
+
+func (g *gauge) Add(val float64) {
+	for {
+		oldBits := atomic.LoadUint64(&g.valBits)
+		newBits := math.Float64bits(math.Float64frombits(oldBits) + val)
+		if atomic.CompareAndSwapUint64(&g.valBits, oldBits, newBits) {
+			return
+		}
+	}
+}
+
+func (g *gauge) Sub(val float64) {
+	g.Add(val * -1)
+}
+
+func (g *gauge) Write(out *dto.Metric) error {
+	val := math.Float64frombits(atomic.LoadUint64(&g.valBits))
+	return populateMetric(GaugeValue, val, g.labelPairs, nil, out)
+}
+
+// GaugeVec is a Collector that bundles a set of Gauges that all share the same
+// Desc, but have different values for their variable labels. This is used if
+// you want to count the same thing partitioned by various dimensions
+// (e.g. number of operations queued, partitioned by user and operation
+// type). Create instances with NewGaugeVec.
+type GaugeVec struct {
+	*MetricVec
+}
+
+// NewGaugeVec creates a new GaugeVec based on the provided GaugeOpts and
+// partitioned by the given label names.
+func NewGaugeVec(opts GaugeOpts, labelNames []string) *GaugeVec {
+	return V2.NewGaugeVec(GaugeVecOpts{
+		GaugeOpts:      opts,
+		VariableLabels: UnconstrainedLabels(labelNames),
+	})
+}
+
+// NewGaugeVec creates a new GaugeVec based on the provided GaugeVecOpts.
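+//
+// A minimal sketch (all field values are illustrative):
+//
+//	queueLength := prometheus.V2.NewGaugeVec(prometheus.GaugeVecOpts{
+//		GaugeOpts: prometheus.GaugeOpts{
+//			Name: "queue_length",
+//			Help: "Number of queued items, partitioned by queue name.",
+//		},
+//		VariableLabels: prometheus.UnconstrainedLabels{"queue"},
+//	})
+//	queueLength.WithLabelValues("ingest").Set(42)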
+func (v2) NewGaugeVec(opts GaugeVecOpts) *GaugeVec { + desc := V2.NewDesc( + BuildFQName(opts.Namespace, opts.Subsystem, opts.Name), + opts.Help, + opts.VariableLabels, + opts.ConstLabels, + ) + return &GaugeVec{ + MetricVec: NewMetricVec(desc, func(lvs ...string) Metric { + if len(lvs) != len(desc.variableLabels) { + panic(makeInconsistentCardinalityError(desc.fqName, desc.variableLabels.labelNames(), lvs)) + } + result := &gauge{desc: desc, labelPairs: MakeLabelPairs(desc, lvs)} + result.init(result) // Init self-collection. + return result + }), + } +} + +// GetMetricWithLabelValues returns the Gauge for the given slice of label +// values (same order as the variable labels in Desc). If that combination of +// label values is accessed for the first time, a new Gauge is created. +// +// It is possible to call this method without using the returned Gauge to only +// create the new Gauge but leave it at its starting value 0. See also the +// SummaryVec example. +// +// Keeping the Gauge for later use is possible (and should be considered if +// performance is critical), but keep in mind that Reset, DeleteLabelValues and +// Delete can be used to delete the Gauge from the GaugeVec. In that case, the +// Gauge will still exist, but it will not be exported anymore, even if a +// Gauge with the same label values is created later. See also the CounterVec +// example. +// +// An error is returned if the number of label values is not the same as the +// number of variable labels in Desc (minus any curried labels). +// +// Note that for more than one label value, this method is prone to mistakes +// caused by an incorrect order of arguments. Consider GetMetricWith(Labels) as +// an alternative to avoid that type of mistake. For higher label numbers, the +// latter has a much more readable (albeit more verbose) syntax, but it comes +// with a performance overhead (for creating and processing the Labels map). +func (v *GaugeVec) GetMetricWithLabelValues(lvs ...string) (Gauge, error) { + metric, err := v.MetricVec.GetMetricWithLabelValues(lvs...) + if metric != nil { + return metric.(Gauge), err + } + return nil, err +} + +// GetMetricWith returns the Gauge for the given Labels map (the label names +// must match those of the variable labels in Desc). If that label map is +// accessed for the first time, a new Gauge is created. Implications of +// creating a Gauge without using it and keeping the Gauge for later use are +// the same as for GetMetricWithLabelValues. +// +// An error is returned if the number and names of the Labels are inconsistent +// with those of the variable labels in Desc (minus any curried labels). +// +// This method is used for the same purpose as +// GetMetricWithLabelValues(...string). See there for pros and cons of the two +// methods. +func (v *GaugeVec) GetMetricWith(labels Labels) (Gauge, error) { + metric, err := v.MetricVec.GetMetricWith(labels) + if metric != nil { + return metric.(Gauge), err + } + return nil, err +} + +// WithLabelValues works as GetMetricWithLabelValues, but panics where +// GetMetricWithLabelValues would have returned an error. Not returning an +// error allows shortcuts like +// +// myVec.WithLabelValues("404", "GET").Add(42) +func (v *GaugeVec) WithLabelValues(lvs ...string) Gauge { + g, err := v.GetMetricWithLabelValues(lvs...) + if err != nil { + panic(err) + } + return g +} + +// With works as GetMetricWith, but panics where GetMetricWithLabels would have +// returned an error. 
Not returning an error allows shortcuts like +// +// myVec.With(prometheus.Labels{"code": "404", "method": "GET"}).Add(42) +func (v *GaugeVec) With(labels Labels) Gauge { + g, err := v.GetMetricWith(labels) + if err != nil { + panic(err) + } + return g +} + +// CurryWith returns a vector curried with the provided labels, i.e. the +// returned vector has those labels pre-set for all labeled operations performed +// on it. The cardinality of the curried vector is reduced accordingly. The +// order of the remaining labels stays the same (just with the curried labels +// taken out of the sequence – which is relevant for the +// (GetMetric)WithLabelValues methods). It is possible to curry a curried +// vector, but only with labels not yet used for currying before. +// +// The metrics contained in the GaugeVec are shared between the curried and +// uncurried vectors. They are just accessed differently. Curried and uncurried +// vectors behave identically in terms of collection. Only one must be +// registered with a given registry (usually the uncurried version). The Reset +// method deletes all metrics, even if called on a curried vector. +func (v *GaugeVec) CurryWith(labels Labels) (*GaugeVec, error) { + vec, err := v.MetricVec.CurryWith(labels) + if vec != nil { + return &GaugeVec{vec}, err + } + return nil, err +} + +// MustCurryWith works as CurryWith but panics where CurryWith would have +// returned an error. +func (v *GaugeVec) MustCurryWith(labels Labels) *GaugeVec { + vec, err := v.CurryWith(labels) + if err != nil { + panic(err) + } + return vec +} + +// GaugeFunc is a Gauge whose value is determined at collect time by calling a +// provided function. +// +// To create GaugeFunc instances, use NewGaugeFunc. +type GaugeFunc interface { + Metric + Collector +} + +// NewGaugeFunc creates a new GaugeFunc based on the provided GaugeOpts. The +// value reported is determined by calling the given function from within the +// Write method. Take into account that metric collection may happen +// concurrently. Therefore, it must be safe to call the provided function +// concurrently. +// +// NewGaugeFunc is a good way to create an “info” style metric with a constant +// value of 1. Example: +// https://github.com/prometheus/common/blob/8558a5b7db3c84fa38b4766966059a7bd5bfa2ee/version/info.go#L36-L56 +func NewGaugeFunc(opts GaugeOpts, function func() float64) GaugeFunc { + return newValueFunc(NewDesc( + BuildFQName(opts.Namespace, opts.Subsystem, opts.Name), + opts.Help, + nil, + opts.ConstLabels, + ), GaugeValue, function) +} diff --git a/vendor/github.com/prometheus/client_golang/prometheus/histogram.go b/vendor/github.com/prometheus/client_golang/prometheus/histogram.go new file mode 100644 index 0000000000..8d818afe90 --- /dev/null +++ b/vendor/github.com/prometheus/client_golang/prometheus/histogram.go @@ -0,0 +1,1499 @@ +// Copyright 2015 The Prometheus Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package prometheus + +import ( + "fmt" + "math" + "runtime" + "sort" + "sync" + "sync/atomic" + "time" + + dto "github.com/prometheus/client_model/go" + + "google.golang.org/protobuf/proto" +) + +// nativeHistogramBounds for the frac of observed values. Only relevant for +// schema > 0. The position in the slice is the schema. (0 is never used, just +// here for convenience of using the schema directly as the index.) +// +// TODO(beorn7): Currently, we do a binary search into these slices. There are +// ways to turn it into a small number of simple array lookups. It probably only +// matters for schema 5 and beyond, but should be investigated. See this comment +// as a starting point: +// https://github.com/open-telemetry/opentelemetry-specification/issues/1776#issuecomment-870164310 +var nativeHistogramBounds = [][]float64{ + // Schema "0": + {0.5}, + // Schema 1: + {0.5, 0.7071067811865475}, + // Schema 2: + {0.5, 0.5946035575013605, 0.7071067811865475, 0.8408964152537144}, + // Schema 3: + { + 0.5, 0.5452538663326288, 0.5946035575013605, 0.6484197773255048, + 0.7071067811865475, 0.7711054127039704, 0.8408964152537144, 0.9170040432046711, + }, + // Schema 4: + { + 0.5, 0.5221368912137069, 0.5452538663326288, 0.5693943173783458, + 0.5946035575013605, 0.620928906036742, 0.6484197773255048, 0.6771277734684463, + 0.7071067811865475, 0.7384130729697496, 0.7711054127039704, 0.805245165974627, + 0.8408964152537144, 0.8781260801866495, 0.9170040432046711, 0.9576032806985735, + }, + // Schema 5: + { + 0.5, 0.5109485743270583, 0.5221368912137069, 0.5335702003384117, + 0.5452538663326288, 0.5571933712979462, 0.5693943173783458, 0.5818624293887887, + 0.5946035575013605, 0.6076236799902344, 0.620928906036742, 0.6345254785958666, + 0.6484197773255048, 0.6626183215798706, 0.6771277734684463, 0.6919549409819159, + 0.7071067811865475, 0.7225904034885232, 0.7384130729697496, 0.7545822137967112, + 0.7711054127039704, 0.7879904225539431, 0.805245165974627, 0.8228777390769823, + 0.8408964152537144, 0.8593096490612387, 0.8781260801866495, 0.8973545375015533, + 0.9170040432046711, 0.9370838170551498, 0.9576032806985735, 0.9785720620876999, + }, + // Schema 6: + { + 0.5, 0.5054446430258502, 0.5109485743270583, 0.5165124395106142, + 0.5221368912137069, 0.5278225891802786, 0.5335702003384117, 0.5393803988785598, + 0.5452538663326288, 0.5511912916539204, 0.5571933712979462, 0.5632608093041209, + 0.5693943173783458, 0.5755946149764913, 0.5818624293887887, 0.5881984958251406, + 0.5946035575013605, 0.6010783657263515, 0.6076236799902344, 0.6142402680534349, + 0.620928906036742, 0.6276903785123455, 0.6345254785958666, 0.6414350080393891, + 0.6484197773255048, 0.6554806057623822, 0.6626183215798706, 0.6698337620266515, + 0.6771277734684463, 0.6845012114872953, 0.6919549409819159, 0.6994898362691555, + 0.7071067811865475, 0.7148066691959849, 0.7225904034885232, 0.7304588970903234, + 0.7384130729697496, 0.7464538641456323, 0.7545822137967112, 0.762799075372269, + 0.7711054127039704, 0.7795022001189185, 0.7879904225539431, 0.7965710756711334, + 0.805245165974627, 0.8140137109286738, 0.8228777390769823, 0.8318382901633681, + 0.8408964152537144, 0.8500531768592616, 0.8593096490612387, 0.8686669176368529, + 0.8781260801866495, 0.8876882462632604, 0.8973545375015533, 0.9071260877501991, + 0.9170040432046711, 0.9269895625416926, 0.9370838170551498, 0.9472879907934827, + 0.9576032806985735, 0.9680308967461471, 0.9785720620876999, 0.9892280131939752, + }, + // Schema 7: + { + 0.5, 0.5027149505564014, 0.5054446430258502, 
0.5081891574554764, + 0.5109485743270583, 0.5137229745593818, 0.5165124395106142, 0.5193170509806894, + 0.5221368912137069, 0.5249720429003435, 0.5278225891802786, 0.5306886136446309, + 0.5335702003384117, 0.5364674337629877, 0.5393803988785598, 0.5423091811066545, + 0.5452538663326288, 0.5482145409081883, 0.5511912916539204, 0.5541842058618393, + 0.5571933712979462, 0.5602188762048033, 0.5632608093041209, 0.5663192597993595, + 0.5693943173783458, 0.572486072215902, 0.5755946149764913, 0.5787200368168754, + 0.5818624293887887, 0.585021884841625, 0.5881984958251406, 0.5913923554921704, + 0.5946035575013605, 0.5978321960199137, 0.6010783657263515, 0.6043421618132907, + 0.6076236799902344, 0.6109230164863786, 0.6142402680534349, 0.6175755319684665, + 0.620928906036742, 0.6243004885946023, 0.6276903785123455, 0.6310986751971253, + 0.6345254785958666, 0.637970889198196, 0.6414350080393891, 0.6449179367033329, + 0.6484197773255048, 0.6519406325959679, 0.6554806057623822, 0.659039800633032, + 0.6626183215798706, 0.6662162735415805, 0.6698337620266515, 0.6734708931164728, + 0.6771277734684463, 0.6808045103191123, 0.6845012114872953, 0.688217985377265, + 0.6919549409819159, 0.6957121878859629, 0.6994898362691555, 0.7032879969095076, + 0.7071067811865475, 0.7109463010845827, 0.7148066691959849, 0.718687998724491, + 0.7225904034885232, 0.7265139979245261, 0.7304588970903234, 0.7344252166684908, + 0.7384130729697496, 0.7424225829363761, 0.7464538641456323, 0.7505070348132126, + 0.7545822137967112, 0.7586795205991071, 0.762799075372269, 0.7669409989204777, + 0.7711054127039704, 0.7752924388424999, 0.7795022001189185, 0.7837348199827764, + 0.7879904225539431, 0.7922691326262467, 0.7965710756711334, 0.8008963778413465, + 0.805245165974627, 0.8096175675974316, 0.8140137109286738, 0.8184337248834821, + 0.8228777390769823, 0.8273458838280969, 0.8318382901633681, 0.8363550898207981, + 0.8408964152537144, 0.8454623996346523, 0.8500531768592616, 0.8546688815502312, + 0.8593096490612387, 0.8639756154809185, 0.8686669176368529, 0.8733836930995842, + 0.8781260801866495, 0.8828942179666361, 0.8876882462632604, 0.8925083056594671, + 0.8973545375015533, 0.9022270839033115, 0.9071260877501991, 0.9120516927035263, + 0.9170040432046711, 0.9219832844793128, 0.9269895625416926, 0.9320230241988943, + 0.9370838170551498, 0.9421720895161669, 0.9472879907934827, 0.9524316709088368, + 0.9576032806985735, 0.9628029718180622, 0.9680308967461471, 0.9732872087896164, + 0.9785720620876999, 0.9838856116165875, 0.9892280131939752, 0.9945994234836328, + }, + // Schema 8: + { + 0.5, 0.5013556375251013, 0.5027149505564014, 0.5040779490592088, + 0.5054446430258502, 0.5068150424757447, 0.5081891574554764, 0.509566998038869, + 0.5109485743270583, 0.5123338964485679, 0.5137229745593818, 0.5151158188430205, + 0.5165124395106142, 0.5179128468009786, 0.5193170509806894, 0.520725062344158, + 0.5221368912137069, 0.5235525479396449, 0.5249720429003435, 0.526395386502313, + 0.5278225891802786, 0.5292536613972564, 0.5306886136446309, 0.5321274564422321, + 0.5335702003384117, 0.5350168559101208, 0.5364674337629877, 0.5379219445313954, + 0.5393803988785598, 0.5408428074966075, 0.5423091811066545, 0.5437795304588847, + 0.5452538663326288, 0.5467321995364429, 0.5482145409081883, 0.549700901315111, + 0.5511912916539204, 0.5526857228508706, 0.5541842058618393, 0.5556867516724088, + 0.5571933712979462, 0.5587040757836845, 0.5602188762048033, 0.5617377836665098, + 0.5632608093041209, 0.564787964283144, 0.5663192597993595, 0.5678547070789026, + 
0.5693943173783458, 0.5709381019847808, 0.572486072215902, 0.5740382394200894, + 0.5755946149764913, 0.5771552102951081, 0.5787200368168754, 0.5802891060137493, + 0.5818624293887887, 0.5834400184762408, 0.585021884841625, 0.5866080400818185, + 0.5881984958251406, 0.5897932637314379, 0.5913923554921704, 0.5929957828304968, + 0.5946035575013605, 0.5962156912915756, 0.5978321960199137, 0.5994530835371903, + 0.6010783657263515, 0.6027080545025619, 0.6043421618132907, 0.6059806996384005, + 0.6076236799902344, 0.6092711149137041, 0.6109230164863786, 0.6125793968185725, + 0.6142402680534349, 0.6159056423670379, 0.6175755319684665, 0.6192499490999082, + 0.620928906036742, 0.622612415087629, 0.6243004885946023, 0.6259931389331581, + 0.6276903785123455, 0.6293922197748583, 0.6310986751971253, 0.6328097572894031, + 0.6345254785958666, 0.6362458516947014, 0.637970889198196, 0.6397006037528346, + 0.6414350080393891, 0.6431741147730128, 0.6449179367033329, 0.6466664866145447, + 0.6484197773255048, 0.6501778216898253, 0.6519406325959679, 0.6537082229673385, + 0.6554806057623822, 0.6572577939746774, 0.659039800633032, 0.6608266388015788, + 0.6626183215798706, 0.6644148621029772, 0.6662162735415805, 0.6680225691020727, + 0.6698337620266515, 0.6716498655934177, 0.6734708931164728, 0.6752968579460171, + 0.6771277734684463, 0.6789636531064505, 0.6808045103191123, 0.6826503586020058, + 0.6845012114872953, 0.6863570825438342, 0.688217985377265, 0.690083933630119, + 0.6919549409819159, 0.6938310211492645, 0.6957121878859629, 0.6975984549830999, + 0.6994898362691555, 0.7013863456101023, 0.7032879969095076, 0.7051948041086352, + 0.7071067811865475, 0.7090239421602076, 0.7109463010845827, 0.7128738720527471, + 0.7148066691959849, 0.7167447066838943, 0.718687998724491, 0.7206365595643126, + 0.7225904034885232, 0.7245495448210174, 0.7265139979245261, 0.7284837772007218, + 0.7304588970903234, 0.7324393720732029, 0.7344252166684908, 0.7364164454346837, + 0.7384130729697496, 0.7404151139112358, 0.7424225829363761, 0.7444354947621984, + 0.7464538641456323, 0.7484777058836176, 0.7505070348132126, 0.7525418658117031, + 0.7545822137967112, 0.7566280937263048, 0.7586795205991071, 0.7607365094544071, + 0.762799075372269, 0.7648672334736434, 0.7669409989204777, 0.7690203869158282, + 0.7711054127039704, 0.7731960915705107, 0.7752924388424999, 0.7773944698885442, + 0.7795022001189185, 0.7816156449856788, 0.7837348199827764, 0.7858597406461707, + 0.7879904225539431, 0.7901268813264122, 0.7922691326262467, 0.7944171921585818, + 0.7965710756711334, 0.7987307989543135, 0.8008963778413465, 0.8030678282083853, + 0.805245165974627, 0.8074284071024302, 0.8096175675974316, 0.8118126635086642, + 0.8140137109286738, 0.8162207259936375, 0.8184337248834821, 0.820652723822003, + 0.8228777390769823, 0.8251087869603088, 0.8273458838280969, 0.8295890460808079, + 0.8318382901633681, 0.8340936325652911, 0.8363550898207981, 0.8386226785089391, + 0.8408964152537144, 0.8431763167241966, 0.8454623996346523, 0.8477546807446661, + 0.8500531768592616, 0.8523579048290255, 0.8546688815502312, 0.8569861239649629, + 0.8593096490612387, 0.8616394738731368, 0.8639756154809185, 0.8663180910111553, + 0.8686669176368529, 0.871022112577578, 0.8733836930995842, 0.8757516765159389, + 0.8781260801866495, 0.8805069215187917, 0.8828942179666361, 0.8852879870317771, + 0.8876882462632604, 0.890095013257712, 0.8925083056594671, 0.8949281411607002, + 0.8973545375015533, 0.8997875124702672, 0.9022270839033115, 0.9046732696855155, + 0.9071260877501991, 0.909585556079304, 
0.9120516927035263, 0.9145245157024483, + 0.9170040432046711, 0.9194902933879467, 0.9219832844793128, 0.9244830347552253, + 0.9269895625416926, 0.92950288621441, 0.9320230241988943, 0.9345499949706191, + 0.9370838170551498, 0.93962450902828, 0.9421720895161669, 0.9447265771954693, + 0.9472879907934827, 0.9498563490882775, 0.9524316709088368, 0.9550139751351947, + 0.9576032806985735, 0.9601996065815236, 0.9628029718180622, 0.9654133954938133, + 0.9680308967461471, 0.9706554947643201, 0.9732872087896164, 0.9759260581154889, + 0.9785720620876999, 0.9812252401044634, 0.9838856116165875, 0.9865531961276168, + 0.9892280131939752, 0.9919100824251095, 0.9945994234836328, 0.9972960560854698, + }, +} + +// The nativeHistogramBounds above can be generated with the code below. +// +// TODO(beorn7): It's tempting to actually use `go generate` to generate the +// code above. However, this could lead to slightly different numbers on +// different architectures. We still need to come to terms if we are fine with +// that, or if we might prefer to specify precise numbers in the standard. +// +// var nativeHistogramBounds [][]float64 = make([][]float64, 9) +// +// func init() { +// // Populate nativeHistogramBounds. +// numBuckets := 1 +// for i := range nativeHistogramBounds { +// bounds := []float64{0.5} +// factor := math.Exp2(math.Exp2(float64(-i))) +// for j := 0; j < numBuckets-1; j++ { +// var bound float64 +// if (j+1)%2 == 0 { +// // Use previously calculated value for increased precision. +// bound = nativeHistogramBounds[i-1][j/2+1] +// } else { +// bound = bounds[j] * factor +// } +// bounds = append(bounds, bound) +// } +// numBuckets *= 2 +// nativeHistogramBounds[i] = bounds +// } +// } + +// A Histogram counts individual observations from an event or sample stream in +// configurable static buckets (or in dynamic sparse buckets as part of the +// experimental Native Histograms, see below for more details). Similar to a +// Summary, it also provides a sum of observations and an observation count. +// +// On the Prometheus server, quantiles can be calculated from a Histogram using +// the histogram_quantile PromQL function. +// +// Note that Histograms, in contrast to Summaries, can be aggregated in PromQL +// (see the documentation for detailed procedures). However, Histograms require +// the user to pre-define suitable buckets, and they are in general less +// accurate. (Both problems are addressed by the experimental Native +// Histograms. To use them, configure a NativeHistogramBucketFactor in the +// HistogramOpts. They also require a Prometheus server v2.40+ with the +// corresponding feature flag enabled.) +// +// The Observe method of a Histogram has a very low performance overhead in +// comparison with the Observe method of a Summary. +// +// To create Histogram instances, use NewHistogram. +type Histogram interface { + Metric + Collector + + // Observe adds a single observation to the histogram. Observations are + // usually positive or zero. Negative observations are accepted but + // prevent current versions of Prometheus from properly detecting + // counter resets in the sum of observations. (The experimental Native + // Histograms handle negative observations properly.) See + // https://prometheus.io/docs/practices/histograms/#count-and-sum-of-observations + // for details. + Observe(float64) +} + +// bucketLabel is used for the label that defines the upper bound of a +// bucket of a histogram ("le" -> "less or equal"). 
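To ground the interface documentation above, a minimal usage sketch may help review. It is illustrative only: the metric name, help string, and observed value are made up, and promhttp is the usual companion package for exposition.

```go
package main

import (
	"net/http"
	"time"

	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promhttp"
)

func main() {
	// Buckets is left nil, so the DefBuckets defined below are used.
	reqDuration := prometheus.NewHistogram(prometheus.HistogramOpts{
		Name: "example_request_duration_seconds",
		Help: "Time spent handling a request.",
	})
	prometheus.MustRegister(reqDuration)

	start := time.Now()
	time.Sleep(25 * time.Millisecond) // Stand-in for real work.
	reqDuration.Observe(time.Since(start).Seconds())

	// Quantiles are then derived server-side, e.g.:
	// histogram_quantile(0.9, rate(example_request_duration_seconds_bucket[5m]))
	http.Handle("/metrics", promhttp.Handler())
	http.ListenAndServe(":8080", nil)
}
```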
+const bucketLabel = "le"
+
+// DefBuckets are the default Histogram buckets. The default buckets are
+// tailored to broadly measure the response time (in seconds) of a network
+// service. Most likely, however, you will be required to define buckets
+// customized to your use case.
+var DefBuckets = []float64{.005, .01, .025, .05, .1, .25, .5, 1, 2.5, 5, 10}
+
+// DefNativeHistogramZeroThreshold is the default value for
+// NativeHistogramZeroThreshold in the HistogramOpts.
+//
+// The value is 2^-128 (or 0.5*2^-127 in the actual IEEE 754 representation),
+// which is a bucket boundary at all possible resolutions.
+const DefNativeHistogramZeroThreshold = 2.938735877055719e-39
+
+// NativeHistogramZeroThresholdZero can be used as NativeHistogramZeroThreshold
+// in the HistogramOpts to create a zero bucket of width zero, i.e. a zero
+// bucket that only receives observations of precisely zero.
+const NativeHistogramZeroThresholdZero = -1
+
+var errBucketLabelNotAllowed = fmt.Errorf(
+	"%q is not allowed as label name in histograms", bucketLabel,
+)
+
+// LinearBuckets creates 'count' regular buckets, each 'width' wide, where the
+// lowest bucket has an upper bound of 'start'. The final +Inf bucket is not
+// counted and not included in the returned slice. The returned slice is meant
+// to be used for the Buckets field of HistogramOpts.
+//
+// The function panics if 'count' is zero or negative.
+func LinearBuckets(start, width float64, count int) []float64 {
+	if count < 1 {
+		panic("LinearBuckets needs a positive count")
+	}
+	buckets := make([]float64, count)
+	for i := range buckets {
+		buckets[i] = start
+		start += width
+	}
+	return buckets
+}
+
+// ExponentialBuckets creates 'count' regular buckets, where the lowest bucket
+// has an upper bound of 'start' and each following bucket's upper bound is
+// 'factor' times the previous bucket's upper bound. The final +Inf bucket is
+// not counted and not included in the returned slice. The returned slice is
+// meant to be used for the Buckets field of HistogramOpts.
+//
+// The function panics if 'count' is 0 or negative, if 'start' is 0 or negative,
+// or if 'factor' is less than or equal to 1.
+func ExponentialBuckets(start, factor float64, count int) []float64 {
+	if count < 1 {
+		panic("ExponentialBuckets needs a positive count")
+	}
+	if start <= 0 {
+		panic("ExponentialBuckets needs a positive start value")
+	}
+	if factor <= 1 {
+		panic("ExponentialBuckets needs a factor greater than 1")
+	}
+	buckets := make([]float64, count)
+	for i := range buckets {
+		buckets[i] = start
+		start *= factor
+	}
+	return buckets
+}
+
+// ExponentialBucketsRange creates 'count' buckets, where the lowest bucket is
+// 'min' and the highest bucket is 'max'. The final +Inf bucket is not counted
+// and not included in the returned slice. The returned slice is meant to be
+// used for the Buckets field of HistogramOpts.
+//
+// The function panics if 'count' is 0 or negative, or if 'min' is 0 or
+// negative.
+func ExponentialBucketsRange(min, max float64, count int) []float64 {
+	if count < 1 {
+		panic("ExponentialBucketsRange needs a positive count")
+	}
+	if min <= 0 {
+		panic("ExponentialBucketsRange min needs to be greater than 0")
+	}
+
+	// Formula for exponential buckets.
+	// max = min*growthFactor^(bucketCount-1)
+
+	// We know max/min and highest bucket. Solve for growthFactor.
+	growthFactor := math.Pow(max/min, 1.0/float64(count-1))
+
+	// Now that we know growthFactor, solve for each bucket.
+	buckets := make([]float64, count)
+	for i := 1; i <= count; i++ {
+		buckets[i-1] = min * math.Pow(growthFactor, float64(i-1))
+	}
+	return buckets
+}
+
+// HistogramOpts bundles the options for creating a Histogram metric. It is
+// mandatory to set Name to a non-empty string. All other fields are optional
+// and can safely be left at their zero value, although it is strongly
+// encouraged to set a Help string.
+type HistogramOpts struct {
+	// Namespace, Subsystem, and Name are components of the fully-qualified
+	// name of the Histogram (created by joining these components with
+	// "_"). Only Name is mandatory, the others merely help to structure the
+	// name. Note that the fully-qualified name of the Histogram must be a
+	// valid Prometheus metric name.
+	Namespace string
+	Subsystem string
+	Name      string
+
+	// Help provides information about this Histogram.
+	//
+	// Metrics with the same fully-qualified name must have the same Help
+	// string.
+	Help string
+
+	// ConstLabels are used to attach fixed labels to this metric. Metrics
+	// with the same fully-qualified name must have the same label names in
+	// their ConstLabels.
+	//
+	// ConstLabels are only used rarely. In particular, do not use them to
+	// attach the same labels to all your metrics. Those use cases are
+	// better covered by target labels set by the scraping Prometheus
+	// server, or by one specific metric (e.g. a build_info or a
+	// machine_role metric). See also
+	// https://prometheus.io/docs/instrumenting/writing_exporters/#target-labels-not-static-scraped-labels
+	ConstLabels Labels
+
+	// Buckets defines the buckets into which observations are counted. Each
+	// element in the slice is the upper inclusive bound of a bucket. The
+	// values must be sorted in strictly increasing order. There is no need
+	// to add a highest bucket with +Inf bound, it will be added
+	// implicitly. If Buckets is left as nil or set to a slice of length
+	// zero, it is replaced by default buckets. The default buckets are
+	// DefBuckets if no buckets for a native histogram (see below) are used,
+	// otherwise the default is no buckets. (In other words, if you want to
+	// use both regular buckets and buckets for a native histogram, you have
+	// to define the regular buckets here explicitly.)
+	Buckets []float64
+
+	// If NativeHistogramBucketFactor is greater than one, so-called sparse
+	// buckets are used (in addition to the regular buckets, if defined
+	// above). A Histogram with sparse buckets will be ingested as a Native
+	// Histogram by a Prometheus server with that feature enabled (requires
+	// Prometheus v2.40+). Sparse buckets are exponential buckets covering
+	// the whole float64 range (with the exception of the “zero” bucket, see
+	// NativeHistogramZeroThreshold below). From any one bucket to the next,
+	// the width of the bucket grows by a constant
+	// factor. NativeHistogramBucketFactor provides an upper bound for this
+	// factor (see below for an exception). The smaller
+	// NativeHistogramBucketFactor, the more buckets will be used and thus
+	// the more costly the histogram will become. A generally good trade-off
+	// between cost and accuracy is a value of 1.1 (each bucket is at most
+	// 10% wider than the previous one), which will result in each power of
+	// two divided into 8 buckets (e.g. there will be 8 buckets between 1
+	// and 2, same as between 2 and 4, and 4 and 8, etc.).
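For quick reference while reviewing the bucket helpers above (the remaining details of this field continue below), a sketch of the slices they produce; the arguments are illustrative:

```go
package main

import (
	"fmt"

	"github.com/prometheus/client_golang/prometheus"
)

func main() {
	fmt.Println(prometheus.LinearBuckets(10, 5, 4))
	// [10 15 20 25]
	fmt.Println(prometheus.ExponentialBuckets(0.1, 2, 5))
	// [0.1 0.2 0.4 0.8 1.6]
	fmt.Println(prometheus.ExponentialBucketsRange(1, 100, 3))
	// [1 10 100]
}
```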
+	//
+	// Details about the actually used factor: The factor is calculated as
+	// 2^(2^n), where n is an integer number between (and including) -8 and
+	// 4. n is chosen so that the resulting factor is the largest that is
+	// still smaller than or equal to NativeHistogramBucketFactor. Note that
+	// the smallest possible factor is therefore approx. 1.00271 (i.e.
+	// 2^(2^-8)). If NativeHistogramBucketFactor is greater than 1 but
+	// smaller than 2^(2^-8), then the actually used factor is still
+	// 2^(2^-8) even though it is larger than the provided
+	// NativeHistogramBucketFactor.
+	//
+	// NOTE: Native Histograms are still an experimental feature. Their
+	// behavior might still change without a major version
+	// bump. Consequently, all NativeHistogram... options here might still
+	// change their behavior or name (or might completely disappear) without
+	// a major version bump.
+	NativeHistogramBucketFactor float64
+	// All observations with an absolute value less than or equal to
+	// NativeHistogramZeroThreshold are accumulated into a “zero”
+	// bucket. For best results, this should be close to a bucket
+	// boundary. This is usually the case if picking a power of two. If
+	// NativeHistogramZeroThreshold is left at zero,
+	// DefNativeHistogramZeroThreshold is used as the threshold. To configure
+	// a zero bucket with an actual threshold of zero (i.e. only
+	// observations of precisely zero will go into the zero bucket), set
+	// NativeHistogramZeroThreshold to the NativeHistogramZeroThresholdZero
+	// constant (or any negative float value).
+	NativeHistogramZeroThreshold float64
+
+	// The remaining fields define a strategy to limit the number of
+	// populated sparse buckets. If NativeHistogramMaxBucketNumber is left
+	// at zero, the number of buckets is not limited. (Note that this might
+	// lead to unbounded memory consumption if the values observed by the
+	// Histogram are sufficiently wide-spread. In particular, this could be
+	// used as a DoS attack vector. Where the observed values depend on
+	// external inputs, it is highly recommended to set a
+	// NativeHistogramMaxBucketNumber.) Once the set
+	// NativeHistogramMaxBucketNumber is exceeded, the following strategy is
+	// enacted: First, if the last reset (or the creation) of the histogram
+	// is at least NativeHistogramMinResetDuration ago, then the whole
+	// histogram is reset to its initial state (including regular
+	// buckets). If less time has passed, or if
+	// NativeHistogramMinResetDuration is zero, no reset is
+	// performed. Instead, the zero threshold is increased sufficiently to
+	// reduce the number of buckets to or below
+	// NativeHistogramMaxBucketNumber, but not to more than
+	// NativeHistogramMaxZeroThreshold. Thus, if
+	// NativeHistogramMaxZeroThreshold is already at or below the current
+	// zero threshold, nothing happens at this step. After that, if the
+	// number of buckets still exceeds NativeHistogramMaxBucketNumber, the
+	// resolution of the histogram is reduced by doubling the width of the
+	// sparse buckets (up to a growth factor from one bucket to the next
+	// of 2^(2^4) = 65536, see above).
+	NativeHistogramMaxBucketNumber  uint32
+	NativeHistogramMinResetDuration time.Duration
+	NativeHistogramMaxZeroThreshold float64
+}
+
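A hedged configuration sketch tying the native-histogram fields above together; all values are illustrative, and the factor arithmetic follows the 2^(2^n) rule described above:

```go
package main

import (
	"time"

	"github.com/prometheus/client_golang/prometheus"
)

func main() {
	// Illustrative values only.
	h := prometheus.NewHistogram(prometheus.HistogramOpts{
		Name: "example_rpc_duration_seconds",
		Help: "RPC latency in seconds.",
		// With NativeHistogramBucketFactor set, classic buckets must be
		// requested explicitly; a nil Buckets would mean none at all.
		Buckets: prometheus.ExponentialBuckets(0.001, 2, 12), // 1ms .. ~2s
		// 1.1 resolves to schema 3: the factor actually used is
		// 2^(2^-3) ≈ 1.0905, i.e. 8 sparse buckets per power of two.
		NativeHistogramBucketFactor: 1.1,
		// Bound the bucket count; allow a full reset at most hourly,
		// per the strategy described above.
		NativeHistogramMaxBucketNumber:  160,
		NativeHistogramMinResetDuration: time.Hour,
	})
	h.Observe(0.042)
}
```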
+// HistogramVecOpts bundles the options to create a HistogramVec metric.
+// It is mandatory to set HistogramOpts; see there for mandatory fields.
+// VariableLabels is optional and can safely be left to its default value.
+type HistogramVecOpts struct {
+	HistogramOpts
+
+	// VariableLabels are used to partition the metric vector by the given set
+	// of labels. Each label value will be constrained with the optional
+	// Constraint function, if provided.
+	VariableLabels ConstrainableLabels
+}
+
+// NewHistogram creates a new Histogram based on the provided HistogramOpts. It
+// panics if the buckets in HistogramOpts are not in strictly increasing order.
+//
+// The returned implementation also implements ExemplarObserver. It is safe to
+// perform the corresponding type assertion. Exemplars are tracked separately
+// for each bucket.
+func NewHistogram(opts HistogramOpts) Histogram {
+	return newHistogram(
+		NewDesc(
+			BuildFQName(opts.Namespace, opts.Subsystem, opts.Name),
+			opts.Help,
+			nil,
+			opts.ConstLabels,
+		),
+		opts,
+	)
+}
+
+func newHistogram(desc *Desc, opts HistogramOpts, labelValues ...string) Histogram {
+	if len(desc.variableLabels) != len(labelValues) {
+		panic(makeInconsistentCardinalityError(desc.fqName, desc.variableLabels.labelNames(), labelValues))
+	}
+
+	for _, n := range desc.variableLabels {
+		if n.Name == bucketLabel {
+			panic(errBucketLabelNotAllowed)
+		}
+	}
+	for _, lp := range desc.constLabelPairs {
+		if lp.GetName() == bucketLabel {
+			panic(errBucketLabelNotAllowed)
+		}
+	}
+
+	h := &histogram{
+		desc:                            desc,
+		upperBounds:                     opts.Buckets,
+		labelPairs:                      MakeLabelPairs(desc, labelValues),
+		nativeHistogramMaxBuckets:       opts.NativeHistogramMaxBucketNumber,
+		nativeHistogramMaxZeroThreshold: opts.NativeHistogramMaxZeroThreshold,
+		nativeHistogramMinResetDuration: opts.NativeHistogramMinResetDuration,
+		lastResetTime:                   time.Now(),
+		now:                             time.Now,
+	}
+	if len(h.upperBounds) == 0 && opts.NativeHistogramBucketFactor <= 1 {
+		h.upperBounds = DefBuckets
+	}
+	if opts.NativeHistogramBucketFactor <= 1 {
+		h.nativeHistogramSchema = math.MinInt32 // To mark that there are no sparse buckets.
+	} else {
+		switch {
+		case opts.NativeHistogramZeroThreshold > 0:
+			h.nativeHistogramZeroThreshold = opts.NativeHistogramZeroThreshold
+		case opts.NativeHistogramZeroThreshold == 0:
+			h.nativeHistogramZeroThreshold = DefNativeHistogramZeroThreshold
+		} // Leave h.nativeHistogramZeroThreshold at 0 otherwise.
+		h.nativeHistogramSchema = pickSchema(opts.NativeHistogramBucketFactor)
+	}
+	for i, upperBound := range h.upperBounds {
+		if i < len(h.upperBounds)-1 {
+			if upperBound >= h.upperBounds[i+1] {
+				panic(fmt.Errorf(
+					"histogram buckets must be in increasing order: %f >= %f",
+					upperBound, h.upperBounds[i+1],
+				))
+			}
+		} else {
+			if math.IsInf(upperBound, +1) {
+				// The +Inf bucket is implicit. Remove it here.
+				h.upperBounds = h.upperBounds[:i]
+			}
+		}
+	}
+	// Finally we know the final length of h.upperBounds and can make buckets
+	// for both counts as well as exemplars:
+	h.counts[0] = &histogramCounts{buckets: make([]uint64, len(h.upperBounds))}
+	atomic.StoreUint64(&h.counts[0].nativeHistogramZeroThresholdBits, math.Float64bits(h.nativeHistogramZeroThreshold))
+	atomic.StoreInt32(&h.counts[0].nativeHistogramSchema, h.nativeHistogramSchema)
+	h.counts[1] = &histogramCounts{buckets: make([]uint64, len(h.upperBounds))}
+	atomic.StoreUint64(&h.counts[1].nativeHistogramZeroThresholdBits, math.Float64bits(h.nativeHistogramZeroThreshold))
+	atomic.StoreInt32(&h.counts[1].nativeHistogramSchema, h.nativeHistogramSchema)
+	h.exemplars = make([]atomic.Value, len(h.upperBounds)+1)
+
+	h.init(h) // Init self-collection.
+	return h
+}
+
+type histogramCounts struct {
+	// Order in this struct matters for the alignment required by atomic
+	// operations, see http://golang.org/pkg/sync/atomic/#pkg-note-BUG
+
+	// sumBits contains the bits of the float64 representing the sum of all
+	// observations.
+	sumBits uint64
+	count   uint64
+
+	// nativeHistogramZeroBucket counts all (positive and negative)
+	// observations in the zero bucket (with an absolute value less than or
+	// equal to the current threshold, see next field).
+	nativeHistogramZeroBucket uint64
+	// nativeHistogramZeroThresholdBits is the bit pattern of the current
+	// threshold for the zero bucket. It's initially equal to
+	// nativeHistogramZeroThreshold but may change according to the bucket
+	// count limitation strategy.
+	nativeHistogramZeroThresholdBits uint64
+	// nativeHistogramSchema may change over time according to the bucket
+	// count limitation strategy and therefore has to be saved here.
+	nativeHistogramSchema int32
+	// Number of (positive and negative) sparse buckets.
+	nativeHistogramBucketsNumber uint32
+
+	// Regular buckets.
+	buckets []uint64
+
+	// The sparse buckets for native histograms are implemented with a
+	// sync.Map for now. A dedicated data structure will likely be more
+	// efficient. There are separate maps for negative and positive
+	// observations. The map's value is an *int64, counting observations in
+	// that bucket. (Note that we don't use uint64 as an int64 won't
+	// overflow in practice, and working with signed numbers from the
+	// beginning simplifies the handling of deltas.) The map's key is the
+	// index of the bucket according to the used
+	// nativeHistogramSchema. Index 0 is for an upper bound of 1.
+	nativeHistogramBucketsPositive, nativeHistogramBucketsNegative sync.Map
+}
+
+// observe manages the parts of observe that only affect
+// histogramCounts. doSparse is true if sparse buckets should be done,
+// too.
+func (hc *histogramCounts) observe(v float64, bucket int, doSparse bool) {
+	if bucket < len(hc.buckets) {
+		atomic.AddUint64(&hc.buckets[bucket], 1)
+	}
+	atomicAddFloat(&hc.sumBits, v)
+	if doSparse && !math.IsNaN(v) {
+		var (
+			key                  int
+			schema               = atomic.LoadInt32(&hc.nativeHistogramSchema)
+			zeroThreshold        = math.Float64frombits(atomic.LoadUint64(&hc.nativeHistogramZeroThresholdBits))
+			bucketCreated, isInf bool
+		)
+		if math.IsInf(v, 0) {
+			// Pretend v is MaxFloat64 but later increment key by one.
+			if math.IsInf(v, +1) {
+				v = math.MaxFloat64
+			} else {
+				v = -math.MaxFloat64
+			}
+			isInf = true
+		}
+		frac, exp := math.Frexp(math.Abs(v))
+		if schema > 0 {
+			bounds := nativeHistogramBounds[schema]
+			key = sort.SearchFloat64s(bounds, frac) + (exp-1)*len(bounds)
+		} else {
+			key = exp
+			if frac == 0.5 {
+				key--
+			}
+			offset := (1 << -schema) - 1
+			key = (key + offset) >> -schema
+		}
+		if isInf {
+			key++
+		}
+		switch {
+		case v > zeroThreshold:
+			bucketCreated = addToBucket(&hc.nativeHistogramBucketsPositive, key, 1)
+		case v < -zeroThreshold:
+			bucketCreated = addToBucket(&hc.nativeHistogramBucketsNegative, key, 1)
+		default:
+			atomic.AddUint64(&hc.nativeHistogramZeroBucket, 1)
+		}
+		if bucketCreated {
+			atomic.AddUint32(&hc.nativeHistogramBucketsNumber, 1)
+		}
+	}
+	// Increment count last as we take it as a signal that the observation
+	// is complete.
+	atomic.AddUint64(&hc.count, 1)
+}
+
+type histogram struct {
+	// countAndHotIdx enables lock-free writes with use of atomic updates.
+	// The most significant bit is the hot index [0 or 1] of the count field
+	// below. Observe calls update the hot one. All remaining bits count the
+	// number of Observe calls. Observe starts by incrementing this counter,
+	// and finishes by incrementing the count field in the respective
+	// histogramCounts, as a marker for completion.
+	//
+	// Calls of the Write method (which are non-mutating reads from the
+	// perspective of the histogram) swap the hot–cold under the writeMtx
+	// lock. A cooldown is awaited (while locked) by comparing the number of
+	// observations with the initiation count. Once they match, then the
+	// last observation on the now cool one has completed. All cold fields must
+	// be merged into the new hot before releasing writeMtx.
+	//
+	// Fields with atomic access first! See alignment constraint:
+	// http://golang.org/pkg/sync/atomic/#pkg-note-BUG
+	countAndHotIdx uint64
+
+	selfCollector
+	desc *Desc
+
+	// Only used in the Write method and for sparse bucket management.
+	mtx sync.Mutex
+
+	// Two counts, one is "hot" for lock-free observations, the other is
+	// "cold" for writing out a dto.Metric. It has to be an array of
+	// pointers to guarantee 64-bit alignment of the histogramCounts, see
+	// http://golang.org/pkg/sync/atomic/#pkg-note-BUG.
+	counts [2]*histogramCounts
+
+	upperBounds                     []float64
+	labelPairs                      []*dto.LabelPair
+	exemplars                       []atomic.Value // One more than buckets (to include +Inf), each a *dto.Exemplar.
+	nativeHistogramSchema           int32          // The initial schema. Set to math.MinInt32 if no sparse buckets are used.
+	nativeHistogramZeroThreshold    float64        // The initial zero threshold.
+	nativeHistogramMaxZeroThreshold float64
+	nativeHistogramMaxBuckets       uint32
+	nativeHistogramMinResetDuration time.Duration
+	lastResetTime                   time.Time // Protected by mtx.
+
+	now func() time.Time // To mock out time.Now() for testing.
+}
+
+func (h *histogram) Desc() *Desc {
+	return h.desc
+}
+
+func (h *histogram) Observe(v float64) {
+	h.observe(v, h.findBucket(v))
+}
+
+func (h *histogram) ObserveWithExemplar(v float64, e Labels) {
+	i := h.findBucket(v)
+	h.observe(v, i)
+	h.updateExemplar(v, i, e)
+}
+
+func (h *histogram) Write(out *dto.Metric) error {
+	// For simplicity, we protect this whole method by a mutex. It is not in
+	// the hot path, i.e. Observe is called much more often than Write. The
+	// complication of making Write lock-free isn't worth it, if possible at
+	// all.
+	h.mtx.Lock()
+	defer h.mtx.Unlock()
+
+	// Adding 1<<63 switches the hot index (from 0 to 1 or from 1 to 0)
+	// without touching the count bits. See the struct comments for a full
+	// description of the algorithm.
+	n := atomic.AddUint64(&h.countAndHotIdx, 1<<63)
+	// count is contained unchanged in the lower 63 bits.
+	count := n & ((1 << 63) - 1)
+	// The most significant bit tells us which counts is hot. The complement
+	// is thus the cold one.
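The two statements following this aside apply that encoding. As a standalone illustration of the bit manipulation (not library code, values illustrative):

```go
package main

import (
	"fmt"
	"sync/atomic"
)

func main() {
	// countAndHotIdx: the most significant bit selects the hot counts,
	// the lower 63 bits count started observations.
	var countAndHotIdx uint64

	// An Observe call: bump the counter and learn the hot index.
	n := atomic.AddUint64(&countAndHotIdx, 1)
	fmt.Println(n>>63, n&((1<<63)-1)) // hot index 0, 1 observation

	// A Write call: adding 1<<63 flips the hot index, counts untouched.
	n = atomic.AddUint64(&countAndHotIdx, 1<<63)
	fmt.Println(n>>63, (^n)>>63, n&((1<<63)-1)) // new hot 1, cold 0, still 1
}
```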
+ hotCounts := h.counts[n>>63] + coldCounts := h.counts[(^n)>>63] + + waitForCooldown(count, coldCounts) + + his := &dto.Histogram{ + Bucket: make([]*dto.Bucket, len(h.upperBounds)), + SampleCount: proto.Uint64(count), + SampleSum: proto.Float64(math.Float64frombits(atomic.LoadUint64(&coldCounts.sumBits))), + } + out.Histogram = his + out.Label = h.labelPairs + + var cumCount uint64 + for i, upperBound := range h.upperBounds { + cumCount += atomic.LoadUint64(&coldCounts.buckets[i]) + his.Bucket[i] = &dto.Bucket{ + CumulativeCount: proto.Uint64(cumCount), + UpperBound: proto.Float64(upperBound), + } + if e := h.exemplars[i].Load(); e != nil { + his.Bucket[i].Exemplar = e.(*dto.Exemplar) + } + } + // If there is an exemplar for the +Inf bucket, we have to add that bucket explicitly. + if e := h.exemplars[len(h.upperBounds)].Load(); e != nil { + b := &dto.Bucket{ + CumulativeCount: proto.Uint64(count), + UpperBound: proto.Float64(math.Inf(1)), + Exemplar: e.(*dto.Exemplar), + } + his.Bucket = append(his.Bucket, b) + } + if h.nativeHistogramSchema > math.MinInt32 { + his.ZeroThreshold = proto.Float64(math.Float64frombits(atomic.LoadUint64(&coldCounts.nativeHistogramZeroThresholdBits))) + his.Schema = proto.Int32(atomic.LoadInt32(&coldCounts.nativeHistogramSchema)) + zeroBucket := atomic.LoadUint64(&coldCounts.nativeHistogramZeroBucket) + + defer func() { + coldCounts.nativeHistogramBucketsPositive.Range(addAndReset(&hotCounts.nativeHistogramBucketsPositive, &hotCounts.nativeHistogramBucketsNumber)) + coldCounts.nativeHistogramBucketsNegative.Range(addAndReset(&hotCounts.nativeHistogramBucketsNegative, &hotCounts.nativeHistogramBucketsNumber)) + }() + + his.ZeroCount = proto.Uint64(zeroBucket) + his.NegativeSpan, his.NegativeDelta = makeBuckets(&coldCounts.nativeHistogramBucketsNegative) + his.PositiveSpan, his.PositiveDelta = makeBuckets(&coldCounts.nativeHistogramBucketsPositive) + } + addAndResetCounts(hotCounts, coldCounts) + return nil +} + +// findBucket returns the index of the bucket for the provided value, or +// len(h.upperBounds) for the +Inf bucket. +func (h *histogram) findBucket(v float64) int { + // TODO(beorn7): For small numbers of buckets (<30), a linear search is + // slightly faster than the binary search. If we really care, we could + // switch from one search strategy to the other depending on the number + // of buckets. + // + // Microbenchmarks (BenchmarkHistogramNoLabels): + // 11 buckets: 38.3 ns/op linear - binary 48.7 ns/op + // 100 buckets: 78.1 ns/op linear - binary 54.9 ns/op + // 300 buckets: 154 ns/op linear - binary 61.6 ns/op + return sort.SearchFloat64s(h.upperBounds, v) +} + +// observe is the implementation for Observe without the findBucket part. +func (h *histogram) observe(v float64, bucket int) { + // Do not add to sparse buckets for NaN observations. + doSparse := h.nativeHistogramSchema > math.MinInt32 && !math.IsNaN(v) + // We increment h.countAndHotIdx so that the counter in the lower + // 63 bits gets incremented. At the same time, we get the new value + // back, which we can use to find the currently-hot counts. + n := atomic.AddUint64(&h.countAndHotIdx, 1) + hotCounts := h.counts[n>>63] + hotCounts.observe(v, bucket, doSparse) + if doSparse { + h.limitBuckets(hotCounts, v, bucket) + } +} + +// limitBuckets applies a strategy to limit the number of populated sparse +// buckets. 
It's generally best effort, and there are situations where the
+// number can go higher (if even the lowest resolution isn't enough to reduce
+// the number sufficiently, or if the provided counts aren't fully updated yet
+// by a concurrently happening Write call).
+func (h *histogram) limitBuckets(counts *histogramCounts, value float64, bucket int) {
+	if h.nativeHistogramMaxBuckets == 0 {
+		return // No limit configured.
+	}
+	if h.nativeHistogramMaxBuckets >= atomic.LoadUint32(&counts.nativeHistogramBucketsNumber) {
+		return // Bucket limit not exceeded yet.
+	}
+
+	h.mtx.Lock()
+	defer h.mtx.Unlock()
+
+	// The hot counts might have been swapped just before we acquired the
+	// lock. Re-fetch the hot counts first...
+	n := atomic.LoadUint64(&h.countAndHotIdx)
+	hotIdx := n >> 63
+	coldIdx := (^n) >> 63
+	hotCounts := h.counts[hotIdx]
+	coldCounts := h.counts[coldIdx]
+	// ...and then check again if we really have to reduce the bucket count.
+	if h.nativeHistogramMaxBuckets >= atomic.LoadUint32(&hotCounts.nativeHistogramBucketsNumber) {
+		return // Bucket limit not exceeded after all.
+	}
+	// Try the various strategies in order.
+	if h.maybeReset(hotCounts, coldCounts, coldIdx, value, bucket) {
+		return
+	}
+	if h.maybeWidenZeroBucket(hotCounts, coldCounts) {
+		return
+	}
+	h.doubleBucketWidth(hotCounts, coldCounts)
+}
+
+// maybeReset resets the whole histogram if at least
+// h.nativeHistogramMinResetDuration has passed. It returns true if the
+// histogram has been reset. The caller must have locked h.mtx.
+func (h *histogram) maybeReset(hot, cold *histogramCounts, coldIdx uint64, value float64, bucket int) bool {
+	// We are using the possibly mocked h.now() rather than
+	// time.Since(h.lastResetTime) to enable testing.
+	if h.nativeHistogramMinResetDuration == 0 || h.now().Sub(h.lastResetTime) < h.nativeHistogramMinResetDuration {
+		return false
+	}
+	// Completely reset coldCounts.
+	h.resetCounts(cold)
+	// Repeat the latest observation to not lose it completely.
+	cold.observe(value, bucket, true)
+	// Make coldCounts the new hot counts while resetting countAndHotIdx.
+	n := atomic.SwapUint64(&h.countAndHotIdx, (coldIdx<<63)+1)
+	count := n & ((1 << 63) - 1)
+	waitForCooldown(count, hot)
+	// Finally, reset the formerly hot counts, too.
+	h.resetCounts(hot)
+	h.lastResetTime = h.now()
+	return true
+}
+
+// maybeWidenZeroBucket widens the zero bucket until it includes the existing
+// buckets closest to the zero bucket (which could be two, if an equidistant
+// negative and a positive bucket exists, but usually it's only one bucket to be
+// merged into the new wider zero bucket). h.nativeHistogramMaxZeroThreshold
+// limits how far the zero bucket can be extended, and if that's not enough to
+// include an existing bucket, the method returns false. The caller must have
+// locked h.mtx.
+func (h *histogram) maybeWidenZeroBucket(hot, cold *histogramCounts) bool {
+	currentZeroThreshold := math.Float64frombits(atomic.LoadUint64(&hot.nativeHistogramZeroThresholdBits))
+	if currentZeroThreshold >= h.nativeHistogramMaxZeroThreshold {
+		return false
+	}
+	// Find the key of the bucket closest to zero.
+ smallestKey := findSmallestKey(&hot.nativeHistogramBucketsPositive) + smallestNegativeKey := findSmallestKey(&hot.nativeHistogramBucketsNegative) + if smallestNegativeKey < smallestKey { + smallestKey = smallestNegativeKey + } + if smallestKey == math.MaxInt32 { + return false + } + newZeroThreshold := getLe(smallestKey, atomic.LoadInt32(&hot.nativeHistogramSchema)) + if newZeroThreshold > h.nativeHistogramMaxZeroThreshold { + return false // New threshold would exceed the max threshold. + } + atomic.StoreUint64(&cold.nativeHistogramZeroThresholdBits, math.Float64bits(newZeroThreshold)) + // Remove applicable buckets. + if _, loaded := cold.nativeHistogramBucketsNegative.LoadAndDelete(smallestKey); loaded { + atomicDecUint32(&cold.nativeHistogramBucketsNumber) + } + if _, loaded := cold.nativeHistogramBucketsPositive.LoadAndDelete(smallestKey); loaded { + atomicDecUint32(&cold.nativeHistogramBucketsNumber) + } + // Make cold counts the new hot counts. + n := atomic.AddUint64(&h.countAndHotIdx, 1<<63) + count := n & ((1 << 63) - 1) + // Swap the pointer names to represent the new roles and make + // the rest less confusing. + hot, cold = cold, hot + waitForCooldown(count, cold) + // Add all the now cold counts to the new hot counts... + addAndResetCounts(hot, cold) + // ...adjust the new zero threshold in the cold counts, too... + atomic.StoreUint64(&cold.nativeHistogramZeroThresholdBits, math.Float64bits(newZeroThreshold)) + // ...and then merge the newly deleted buckets into the wider zero + // bucket. + mergeAndDeleteOrAddAndReset := func(hotBuckets, coldBuckets *sync.Map) func(k, v interface{}) bool { + return func(k, v interface{}) bool { + key := k.(int) + bucket := v.(*int64) + if key == smallestKey { + // Merge into hot zero bucket... + atomic.AddUint64(&hot.nativeHistogramZeroBucket, uint64(atomic.LoadInt64(bucket))) + // ...and delete from cold counts. + coldBuckets.Delete(key) + atomicDecUint32(&cold.nativeHistogramBucketsNumber) + } else { + // Add to corresponding hot bucket... + if addToBucket(hotBuckets, key, atomic.LoadInt64(bucket)) { + atomic.AddUint32(&hot.nativeHistogramBucketsNumber, 1) + } + // ...and reset cold bucket. + atomic.StoreInt64(bucket, 0) + } + return true + } + } + + cold.nativeHistogramBucketsPositive.Range(mergeAndDeleteOrAddAndReset(&hot.nativeHistogramBucketsPositive, &cold.nativeHistogramBucketsPositive)) + cold.nativeHistogramBucketsNegative.Range(mergeAndDeleteOrAddAndReset(&hot.nativeHistogramBucketsNegative, &cold.nativeHistogramBucketsNegative)) + return true +} + +// doubleBucketWidth doubles the bucket width (by decrementing the schema +// number). Note that very sparse buckets could lead to a low reduction of the +// bucket count (or even no reduction at all). The method does nothing if the +// schema is already -4. +func (h *histogram) doubleBucketWidth(hot, cold *histogramCounts) { + coldSchema := atomic.LoadInt32(&cold.nativeHistogramSchema) + if coldSchema == -4 { + return // Already at lowest resolution. + } + coldSchema-- + atomic.StoreInt32(&cold.nativeHistogramSchema, coldSchema) + // Play it simple and just delete all cold buckets. + atomic.StoreUint32(&cold.nativeHistogramBucketsNumber, 0) + deleteSyncMap(&cold.nativeHistogramBucketsNegative) + deleteSyncMap(&cold.nativeHistogramBucketsPositive) + // Make coldCounts the new hot counts. + n := atomic.AddUint64(&h.countAndHotIdx, 1<<63) + count := n & ((1 << 63) - 1) + // Swap the pointer names to represent the new roles and make + // the rest less confusing. 
+ hot, cold = cold, hot + waitForCooldown(count, cold) + // Add all the now cold counts to the new hot counts... + addAndResetCounts(hot, cold) + // ...adjust the schema in the cold counts, too... + atomic.StoreInt32(&cold.nativeHistogramSchema, coldSchema) + // ...and then merge the cold buckets into the wider hot buckets. + merge := func(hotBuckets *sync.Map) func(k, v interface{}) bool { + return func(k, v interface{}) bool { + key := k.(int) + bucket := v.(*int64) + // Adjust key to match the bucket to merge into. + if key > 0 { + key++ + } + key /= 2 + // Add to corresponding hot bucket. + if addToBucket(hotBuckets, key, atomic.LoadInt64(bucket)) { + atomic.AddUint32(&hot.nativeHistogramBucketsNumber, 1) + } + return true + } + } + + cold.nativeHistogramBucketsPositive.Range(merge(&hot.nativeHistogramBucketsPositive)) + cold.nativeHistogramBucketsNegative.Range(merge(&hot.nativeHistogramBucketsNegative)) + // Play it simple again and just delete all cold buckets. + atomic.StoreUint32(&cold.nativeHistogramBucketsNumber, 0) + deleteSyncMap(&cold.nativeHistogramBucketsNegative) + deleteSyncMap(&cold.nativeHistogramBucketsPositive) +} + +func (h *histogram) resetCounts(counts *histogramCounts) { + atomic.StoreUint64(&counts.sumBits, 0) + atomic.StoreUint64(&counts.count, 0) + atomic.StoreUint64(&counts.nativeHistogramZeroBucket, 0) + atomic.StoreUint64(&counts.nativeHistogramZeroThresholdBits, math.Float64bits(h.nativeHistogramZeroThreshold)) + atomic.StoreInt32(&counts.nativeHistogramSchema, h.nativeHistogramSchema) + atomic.StoreUint32(&counts.nativeHistogramBucketsNumber, 0) + for i := range h.upperBounds { + atomic.StoreUint64(&counts.buckets[i], 0) + } + deleteSyncMap(&counts.nativeHistogramBucketsNegative) + deleteSyncMap(&counts.nativeHistogramBucketsPositive) +} + +// updateExemplar replaces the exemplar for the provided bucket. With empty +// labels, it's a no-op. It panics if any of the labels is invalid. +func (h *histogram) updateExemplar(v float64, bucket int, l Labels) { + if l == nil { + return + } + e, err := newExemplar(v, h.now(), l) + if err != nil { + panic(err) + } + h.exemplars[bucket].Store(e) +} + +// HistogramVec is a Collector that bundles a set of Histograms that all share the +// same Desc, but have different values for their variable labels. This is used +// if you want to count the same thing partitioned by various dimensions +// (e.g. HTTP request latencies, partitioned by status code and method). Create +// instances with NewHistogramVec. +type HistogramVec struct { + *MetricVec +} + +// NewHistogramVec creates a new HistogramVec based on the provided HistogramOpts and +// partitioned by the given label names. +func NewHistogramVec(opts HistogramOpts, labelNames []string) *HistogramVec { + return V2.NewHistogramVec(HistogramVecOpts{ + HistogramOpts: opts, + VariableLabels: UnconstrainedLabels(labelNames), + }) +} + +// NewHistogramVec creates a new HistogramVec based on the provided HistogramVecOpts. +func (v2) NewHistogramVec(opts HistogramVecOpts) *HistogramVec { + desc := V2.NewDesc( + BuildFQName(opts.Namespace, opts.Subsystem, opts.Name), + opts.Help, + opts.VariableLabels, + opts.ConstLabels, + ) + return &HistogramVec{ + MetricVec: NewMetricVec(desc, func(lvs ...string) Metric { + return newHistogram(desc, opts.HistogramOpts, lvs...) + }), + } +} + +// GetMetricWithLabelValues returns the Histogram for the given slice of label +// values (same order as the variable labels in Desc). 
If that combination of
+// label values is accessed for the first time, a new Histogram is created.
+//
+// It is possible to call this method without using the returned Histogram to only
+// create the new Histogram but leave it at its starting value, a Histogram without
+// any observations.
+//
+// Keeping the Histogram for later use is possible (and should be considered if
+// performance is critical), but keep in mind that Reset, DeleteLabelValues and
+// Delete can be used to delete the Histogram from the HistogramVec. In that case, the
+// Histogram will still exist, but it will not be exported anymore, even if a
+// Histogram with the same label values is created later. See also the CounterVec
+// example.
+//
+// An error is returned if the number of label values is not the same as the
+// number of variable labels in Desc (minus any curried labels).
+//
+// Note that for more than one label value, this method is prone to mistakes
+// caused by an incorrect order of arguments. Consider GetMetricWith(Labels) as
+// an alternative to avoid that type of mistake. For higher label numbers, the
+// latter has a much more readable (albeit more verbose) syntax, but it comes
+// with a performance overhead (for creating and processing the Labels map).
+// See also the GaugeVec example.
+func (v *HistogramVec) GetMetricWithLabelValues(lvs ...string) (Observer, error) {
+	metric, err := v.MetricVec.GetMetricWithLabelValues(lvs...)
+	if metric != nil {
+		return metric.(Observer), err
+	}
+	return nil, err
+}
+
+// GetMetricWith returns the Histogram for the given Labels map (the label names
+// must match those of the variable labels in Desc). If that label map is
+// accessed for the first time, a new Histogram is created. Implications of
+// creating a Histogram without using it and keeping the Histogram for later use
+// are the same as for GetMetricWithLabelValues.
+//
+// An error is returned if the number and names of the Labels are inconsistent
+// with those of the variable labels in Desc (minus any curried labels).
+//
+// This method is used for the same purpose as
+// GetMetricWithLabelValues(...string). See there for pros and cons of the two
+// methods.
+func (v *HistogramVec) GetMetricWith(labels Labels) (Observer, error) {
+	metric, err := v.MetricVec.GetMetricWith(labels)
+	if metric != nil {
+		return metric.(Observer), err
+	}
+	return nil, err
+}
+
+// WithLabelValues works as GetMetricWithLabelValues, but panics where
+// GetMetricWithLabelValues would have returned an error. Not returning an
+// error allows shortcuts like
+//
+//	myVec.WithLabelValues("404", "GET").Observe(42.21)
+func (v *HistogramVec) WithLabelValues(lvs ...string) Observer {
+	h, err := v.GetMetricWithLabelValues(lvs...)
+	if err != nil {
+		panic(err)
+	}
+	return h
+}
+
+// With works as GetMetricWith but panics where GetMetricWith would have
+// returned an error. Not returning an error allows shortcuts like
+//
+//	myVec.With(prometheus.Labels{"code": "404", "method": "GET"}).Observe(42.21)
+func (v *HistogramVec) With(labels Labels) Observer {
+	h, err := v.GetMetricWith(labels)
+	if err != nil {
+		panic(err)
+	}
+	return h
+}
+
+// CurryWith returns a vector curried with the provided labels, i.e. the
+// returned vector has those labels pre-set for all labeled operations performed
+// on it. The cardinality of the curried vector is reduced accordingly.
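A brief, illustrative sketch of the HistogramVec methods above, including currying (the CurryWith description continues below; label names and values are made up):

```go
package main

import (
	"fmt"

	"github.com/prometheus/client_golang/prometheus"
)

func main() {
	latency := prometheus.NewHistogramVec(prometheus.HistogramOpts{
		Name:    "example_http_request_duration_seconds",
		Help:    "HTTP latency, partitioned by status code and method.",
		Buckets: prometheus.DefBuckets,
	}, []string{"code", "method"})

	// Panicking shortcuts, as documented above.
	latency.WithLabelValues("200", "GET").Observe(0.042)
	latency.With(prometheus.Labels{"code": "404", "method": "GET"}).Observe(0.007)

	// Currying pins the method label; the curried vec shares its metrics
	// with the original.
	getOnly, err := latency.CurryWith(prometheus.Labels{"method": "GET"})
	if err != nil {
		fmt.Println(err)
		return
	}
	getOnly.WithLabelValues("500").Observe(0.123)
}
```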
The +// order of the remaining labels stays the same (just with the curried labels +// taken out of the sequence – which is relevant for the +// (GetMetric)WithLabelValues methods). It is possible to curry a curried +// vector, but only with labels not yet used for currying before. +// +// The metrics contained in the HistogramVec are shared between the curried and +// uncurried vectors. They are just accessed differently. Curried and uncurried +// vectors behave identically in terms of collection. Only one must be +// registered with a given registry (usually the uncurried version). The Reset +// method deletes all metrics, even if called on a curried vector. +func (v *HistogramVec) CurryWith(labels Labels) (ObserverVec, error) { + vec, err := v.MetricVec.CurryWith(labels) + if vec != nil { + return &HistogramVec{vec}, err + } + return nil, err +} + +// MustCurryWith works as CurryWith but panics where CurryWith would have +// returned an error. +func (v *HistogramVec) MustCurryWith(labels Labels) ObserverVec { + vec, err := v.CurryWith(labels) + if err != nil { + panic(err) + } + return vec +} + +type constHistogram struct { + desc *Desc + count uint64 + sum float64 + buckets map[float64]uint64 + labelPairs []*dto.LabelPair +} + +func (h *constHistogram) Desc() *Desc { + return h.desc +} + +func (h *constHistogram) Write(out *dto.Metric) error { + his := &dto.Histogram{} + + buckets := make([]*dto.Bucket, 0, len(h.buckets)) + + his.SampleCount = proto.Uint64(h.count) + his.SampleSum = proto.Float64(h.sum) + for upperBound, count := range h.buckets { + buckets = append(buckets, &dto.Bucket{ + CumulativeCount: proto.Uint64(count), + UpperBound: proto.Float64(upperBound), + }) + } + + if len(buckets) > 0 { + sort.Sort(buckSort(buckets)) + } + his.Bucket = buckets + + out.Histogram = his + out.Label = h.labelPairs + + return nil +} + +// NewConstHistogram returns a metric representing a Prometheus histogram with +// fixed values for the count, sum, and bucket counts. As those parameters +// cannot be changed, the returned value does not implement the Histogram +// interface (but only the Metric interface). Users of this package will not +// have much use for it in regular operations. However, when implementing custom +// Collectors, it is useful as a throw-away metric that is generated on the fly +// to send it to Prometheus in the Collect method. +// +// buckets is a map of upper bounds to cumulative counts, excluding the +Inf +// bucket. The +Inf bucket is implicit, and its value is equal to the provided count. +// +// NewConstHistogram returns an error if the length of labelValues is not +// consistent with the variable labels in Desc or if Desc is invalid. +func NewConstHistogram( + desc *Desc, + count uint64, + sum float64, + buckets map[float64]uint64, + labelValues ...string, +) (Metric, error) { + if desc.err != nil { + return nil, desc.err + } + if err := validateLabelValues(labelValues, len(desc.variableLabels)); err != nil { + return nil, err + } + return &constHistogram{ + desc: desc, + count: count, + sum: sum, + buckets: buckets, + labelPairs: MakeLabelPairs(desc, labelValues), + }, nil +} + +// MustNewConstHistogram is a version of NewConstHistogram that panics where +// NewConstHistogram would have returned an error. +func MustNewConstHistogram( + desc *Desc, + count uint64, + sum float64, + buckets map[float64]uint64, + labelValues ...string, +) Metric { + m, err := NewConstHistogram(desc, count, sum, buckets, labelValues...) 
+	if err != nil {
+		panic(err)
+	}
+	return m
+}
+
+type buckSort []*dto.Bucket
+
+func (s buckSort) Len() int {
+	return len(s)
+}
+
+func (s buckSort) Swap(i, j int) {
+	s[i], s[j] = s[j], s[i]
+}
+
+func (s buckSort) Less(i, j int) bool {
+	return s[i].GetUpperBound() < s[j].GetUpperBound()
+}
+
+// pickSchema returns the smallest number n between -4 and 8 such that
+// 2^(2^-n) is less than or equal to the provided bucketFactor, i.e. the
+// schema whose factor is the largest that does not exceed bucketFactor.
+//
+// Special cases:
+// - bucketFactor <= 1: panics.
+// - bucketFactor < 2^(2^-8) (but > 1): still returns 8.
+func pickSchema(bucketFactor float64) int32 {
+	if bucketFactor <= 1 {
+		panic(fmt.Errorf("bucketFactor %f is <=1", bucketFactor))
+	}
+	floor := math.Floor(math.Log2(math.Log2(bucketFactor)))
+	switch {
+	case floor <= -8:
+		return 8
+	case floor >= 4:
+		return -4
+	default:
+		return -int32(floor)
+	}
+}
+
+func makeBuckets(buckets *sync.Map) ([]*dto.BucketSpan, []int64) {
+	var ii []int
+	buckets.Range(func(k, v interface{}) bool {
+		ii = append(ii, k.(int))
+		return true
+	})
+	sort.Ints(ii)
+
+	if len(ii) == 0 {
+		return nil, nil
+	}
+
+	var (
+		spans     []*dto.BucketSpan
+		deltas    []int64
+		prevCount int64
+		nextI     int
+	)
+
+	appendDelta := func(count int64) {
+		*spans[len(spans)-1].Length++
+		deltas = append(deltas, count-prevCount)
+		prevCount = count
+	}
+
+	for n, i := range ii {
+		v, _ := buckets.Load(i)
+		count := atomic.LoadInt64(v.(*int64))
+		// Multiple spans with only small gaps in between are probably
+		// encoded more efficiently as one larger span with a few empty
+		// buckets. Needs some research to find the sweet spot. For now,
+		// we assume that gaps of one or two buckets should not create
+		// a new span.
+		iDelta := int32(i - nextI)
+		if n == 0 || iDelta > 2 {
+			// We have to create a new span, either because we are
+			// at the very beginning, or because we have found a gap
+			// of more than two buckets.
+			spans = append(spans, &dto.BucketSpan{
+				Offset: proto.Int32(iDelta),
+				Length: proto.Uint32(0),
+			})
+		} else {
+			// We have found a small gap (or no gap at all).
+			// Insert empty buckets as needed.
+			for j := int32(0); j < iDelta; j++ {
+				appendDelta(0)
+			}
+		}
+		appendDelta(count)
+		nextI = i + 1
+	}
+	return spans, deltas
+}
+
+// addToBucket increments the sparse bucket at key by the provided amount. It
+// returns true if a new sparse bucket had to be created for that.
+func addToBucket(buckets *sync.Map, key int, increment int64) bool {
+	if existingBucket, ok := buckets.Load(key); ok {
+		// Fast path without allocation.
+		atomic.AddInt64(existingBucket.(*int64), increment)
+		return false
+	}
+	// Bucket doesn't exist yet. Slow path allocating new counter.
+	newBucket := increment // TODO(beorn7): Check if this is sufficient to not let increment escape.
+	if actualBucket, loaded := buckets.LoadOrStore(key, &newBucket); loaded {
+		// The bucket was created concurrently in another goroutine.
+		// Have to increment after all.
+		atomic.AddInt64(actualBucket.(*int64), increment)
+		return false
+	}
+	return true
+}
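To illustrate the span/delta encoding produced by makeBuckets above, a simplified sketch over a plain map instead of a sync.Map; the bucket keys and counts are hypothetical:

```go
package main

import "fmt"

// encode mirrors the span/delta logic of makeBuckets in a simplified form:
// spans are (offset, length) pairs, deltas are count differences.
func encode(buckets map[int]int64, keys []int) (spans [][2]int, deltas []int64) {
	var prev int64
	nextI := 0
	for n, i := range keys {
		iDelta := i - nextI
		if n == 0 || iDelta > 2 {
			// New span at the start or after a gap of more than two.
			spans = append(spans, [2]int{iDelta, 0})
		} else {
			// Small gap: fill with empty buckets instead.
			for j := 0; j < iDelta; j++ {
				spans[len(spans)-1][1]++
				deltas = append(deltas, 0-prev)
				prev = 0
			}
		}
		spans[len(spans)-1][1]++
		deltas = append(deltas, buckets[i]-prev)
		prev = buckets[i]
		nextI = i + 1
	}
	return spans, deltas
}

func main() {
	// Hypothetical sparse buckets: key -> count.
	b := map[int]int64{1: 3, 2: 5, 6: 2}
	spans, deltas := encode(b, []int{1, 2, 6})
	fmt.Println(spans, deltas) // [[1 2] [3 1]] [3 2 -3]
}
```

Running this prints one span of two buckets starting at offset 1, then, after a gap of three keys, a single bucket; the counts 3, 5, 2 are encoded as the deltas 3, 2, -3.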
+
+// addAndReset returns a function to be used with sync.Map.Range of sparse
+// buckets in coldCounts. It increments the buckets in the provided hotBuckets
+// according to the buckets ranged through. It then resets all buckets ranged
+// through to 0 (but leaves them in place so that they don't need to get
+// recreated on the next scrape).
+func addAndReset(hotBuckets *sync.Map, bucketNumber *uint32) func(k, v interface{}) bool {
+	return func(k, v interface{}) bool {
+		bucket := v.(*int64)
+		if addToBucket(hotBuckets, k.(int), atomic.LoadInt64(bucket)) {
+			atomic.AddUint32(bucketNumber, 1)
+		}
+		atomic.StoreInt64(bucket, 0)
+		return true
+	}
+}
+
+func deleteSyncMap(m *sync.Map) {
+	m.Range(func(k, v interface{}) bool {
+		m.Delete(k)
+		return true
+	})
+}
+
+func findSmallestKey(m *sync.Map) int {
+	result := math.MaxInt32
+	m.Range(func(k, v interface{}) bool {
+		key := k.(int)
+		if key < result {
+			result = key
+		}
+		return true
+	})
+	return result
+}
+
+func getLe(key int, schema int32) float64 {
+	// Here a bit of context about the behavior for the last bucket counting
+	// regular numbers (called simply "last bucket" below) and the bucket
+	// counting observations of ±Inf (called "inf bucket" below, with a key
+	// one higher than that of the "last bucket"):
+	//
+	// If we apply the usual formula to the last bucket, its upper bound
+	// would be calculated as +Inf. The reason is that the max possible
+	// regular float64 number (math.MaxFloat64) doesn't coincide with one of
+	// the calculated bucket boundaries. So the calculated boundary has to
+	// be larger than math.MaxFloat64, and the only float64 larger than
+	// math.MaxFloat64 is +Inf. However, we want to count actual
+	// observations of ±Inf in the inf bucket. Therefore, we have to treat
+	// the upper bound of the last bucket specially and set it to
+	// math.MaxFloat64. (The upper bound of the inf bucket, with its key
+	// being one higher than that of the last bucket, naturally comes out as
+	// +Inf by the usual formula. So that's fine.)
+	//
+	// math.MaxFloat64 has a frac of 0.9999999999999999 and an exp of
+	// 1024. If there were a float64 number following math.MaxFloat64, it
+	// would have a frac of 1.0 and an exp of 1024, or equivalently a frac
+	// of 0.5 and an exp of 1025. However, since frac must be smaller than
+	// 1, and exp must be smaller than 1025, either representation overflows
+	// a float64. (Which, in turn, is the reason that math.MaxFloat64 is the
+	// largest possible float64. Q.E.D.) However, the formula for
+	// calculating the upper bound from the idx and schema of the last
+	// bucket results in precisely that. It is either frac=1.0 & exp=1024
+	// (for schema < 0) or frac=0.5 & exp=1025 (for schema >= 0). (This is,
+	// by the way, a power of two where the exponent itself is a power of
+	// two, 2¹⁰ in fact, which coincides with a bucket boundary in all
+	// schemas.) So these are the special cases we have to catch below.
+	if schema < 0 {
+		exp := key << -schema
+		if exp == 1024 {
+			// This is the last bucket before the overflow bucket
+			// (for ±Inf observations). Return math.MaxFloat64 as
+			// explained above.
+			return math.MaxFloat64
+		}
+		return math.Ldexp(1, exp)
+	}
+
+	fracIdx := key & ((1 << schema) - 1)
+	frac := nativeHistogramBounds[schema][fracIdx]
+	exp := (key >> schema) + 1
+	if frac == 0.5 && exp == 1025 {
+		// This is the last bucket before the overflow bucket (for ±Inf
+		// observations). Return math.MaxFloat64 as explained above.
+		return math.MaxFloat64
+	}
+	return math.Ldexp(frac, exp)
+}
+
+// waitForCooldown returns after the count field in the provided histogramCounts
+// has reached the provided count value.
+func waitForCooldown(count uint64, counts *histogramCounts) {
+	for count != atomic.LoadUint64(&counts.count) {
+		runtime.Gosched() // Let observations get work done.
+	}
+}
+
+// atomicAddFloat adds the provided float atomically to another float
+// represented by the bit pattern the bits pointer is pointing to.
+func atomicAddFloat(bits *uint64, v float64) {
+	for {
+		loadedBits := atomic.LoadUint64(bits)
+		newBits := math.Float64bits(math.Float64frombits(loadedBits) + v)
+		if atomic.CompareAndSwapUint64(bits, loadedBits, newBits) {
+			break
+		}
+	}
+}
+
+// atomicDecUint32 atomically decrements the uint32 p points to. See
+// https://pkg.go.dev/sync/atomic#AddUint32 to understand how this is done.
+func atomicDecUint32(p *uint32) {
+	atomic.AddUint32(p, ^uint32(0))
+}
+
+// addAndResetCounts adds certain fields (count, sum, conventional buckets, zero
+// bucket) from the cold counts to the corresponding fields in the hot
+// counts. Those fields are then reset to 0 in the cold counts.
+func addAndResetCounts(hot, cold *histogramCounts) {
+	atomic.AddUint64(&hot.count, atomic.LoadUint64(&cold.count))
+	atomic.StoreUint64(&cold.count, 0)
+	coldSum := math.Float64frombits(atomic.LoadUint64(&cold.sumBits))
+	atomicAddFloat(&hot.sumBits, coldSum)
+	atomic.StoreUint64(&cold.sumBits, 0)
+	for i := range hot.buckets {
+		atomic.AddUint64(&hot.buckets[i], atomic.LoadUint64(&cold.buckets[i]))
+		atomic.StoreUint64(&cold.buckets[i], 0)
+	}
+	atomic.AddUint64(&hot.nativeHistogramZeroBucket, atomic.LoadUint64(&cold.nativeHistogramZeroBucket))
+	atomic.StoreUint64(&cold.nativeHistogramZeroBucket, 0)
+}
diff --git a/vendor/github.com/prometheus/client_golang/prometheus/internal/difflib.go b/vendor/github.com/prometheus/client_golang/prometheus/internal/difflib.go
new file mode 100644
index 0000000000..fd0750f2cf
--- /dev/null
+++ b/vendor/github.com/prometheus/client_golang/prometheus/internal/difflib.go
@@ -0,0 +1,654 @@
+// Copyright 2022 The Prometheus Authors
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// It provides tools to compare sequences of strings and generate textual diffs.
+//
+// Maintaining `GetUnifiedDiffString` here because original repository
+// (https://github.com/pmezard/go-difflib) is no longer maintained.
+package internal
+
+import (
+	"bufio"
+	"bytes"
+	"fmt"
+	"io"
+	"strings"
+)
+
+func min(a, b int) int {
+	if a < b {
+		return a
+	}
+	return b
+}
+
+func max(a, b int) int {
+	if a > b {
+		return a
+	}
+	return b
+}
+
+func calculateRatio(matches, length int) float64 {
+	if length > 0 {
+		return 2.0 * float64(matches) / float64(length)
+	}
+	return 1.0
+}
+
+type Match struct {
+	A    int
+	B    int
+	Size int
+}
+
+type OpCode struct {
+	Tag byte
+	I1  int
+	I2  int
+	J1  int
+	J2  int
+}
+
+// SequenceMatcher compares sequences of strings. The basic
+// algorithm predates, and is a little fancier than, an algorithm
+// published in the late 1980's by Ratcliff and Obershelp under the
+// hyperbolic name "gestalt pattern matching". The basic idea is to find
+// the longest contiguous matching subsequence that contains no "junk"
+// elements (R-O doesn't address junk).
The same idea is then applied +// recursively to the pieces of the sequences to the left and to the right +// of the matching subsequence. This does not yield minimal edit +// sequences, but does tend to yield matches that "look right" to people. +// +// SequenceMatcher tries to compute a "human-friendly diff" between two +// sequences. Unlike e.g. UNIX(tm) diff, the fundamental notion is the +// longest *contiguous* & junk-free matching subsequence. That's what +// catches peoples' eyes. The Windows(tm) windiff has another interesting +// notion, pairing up elements that appear uniquely in each sequence. +// That, and the method here, appear to yield more intuitive difference +// reports than does diff. This method appears to be the least vulnerable +// to synching up on blocks of "junk lines", though (like blank lines in +// ordinary text files, or maybe "
<P>
" lines in HTML files). That may be +// because this is the only method of the 3 that has a *concept* of +// "junk" . +// +// Timing: Basic R-O is cubic time worst case and quadratic time expected +// case. SequenceMatcher is quadratic time for the worst case and has +// expected-case behavior dependent in a complicated way on how many +// elements the sequences have in common; best case time is linear. +type SequenceMatcher struct { + a []string + b []string + b2j map[string][]int + IsJunk func(string) bool + autoJunk bool + bJunk map[string]struct{} + matchingBlocks []Match + fullBCount map[string]int + bPopular map[string]struct{} + opCodes []OpCode +} + +func NewMatcher(a, b []string) *SequenceMatcher { + m := SequenceMatcher{autoJunk: true} + m.SetSeqs(a, b) + return &m +} + +func NewMatcherWithJunk(a, b []string, autoJunk bool, + isJunk func(string) bool, +) *SequenceMatcher { + m := SequenceMatcher{IsJunk: isJunk, autoJunk: autoJunk} + m.SetSeqs(a, b) + return &m +} + +// Set two sequences to be compared. +func (m *SequenceMatcher) SetSeqs(a, b []string) { + m.SetSeq1(a) + m.SetSeq2(b) +} + +// Set the first sequence to be compared. The second sequence to be compared is +// not changed. +// +// SequenceMatcher computes and caches detailed information about the second +// sequence, so if you want to compare one sequence S against many sequences, +// use .SetSeq2(s) once and call .SetSeq1(x) repeatedly for each of the other +// sequences. +// +// See also SetSeqs() and SetSeq2(). +func (m *SequenceMatcher) SetSeq1(a []string) { + if &a == &m.a { + return + } + m.a = a + m.matchingBlocks = nil + m.opCodes = nil +} + +// Set the second sequence to be compared. The first sequence to be compared is +// not changed. +func (m *SequenceMatcher) SetSeq2(b []string) { + if &b == &m.b { + return + } + m.b = b + m.matchingBlocks = nil + m.opCodes = nil + m.fullBCount = nil + m.chainB() +} + +func (m *SequenceMatcher) chainB() { + // Populate line -> index mapping + b2j := map[string][]int{} + for i, s := range m.b { + indices := b2j[s] + indices = append(indices, i) + b2j[s] = indices + } + + // Purge junk elements + m.bJunk = map[string]struct{}{} + if m.IsJunk != nil { + junk := m.bJunk + for s := range b2j { + if m.IsJunk(s) { + junk[s] = struct{}{} + } + } + for s := range junk { + delete(b2j, s) + } + } + + // Purge remaining popular elements + popular := map[string]struct{}{} + n := len(m.b) + if m.autoJunk && n >= 200 { + ntest := n/100 + 1 + for s, indices := range b2j { + if len(indices) > ntest { + popular[s] = struct{}{} + } + } + for s := range popular { + delete(b2j, s) + } + } + m.bPopular = popular + m.b2j = b2j +} + +func (m *SequenceMatcher) isBJunk(s string) bool { + _, ok := m.bJunk[s] + return ok +} + +// Find longest matching block in a[alo:ahi] and b[blo:bhi]. +// +// If IsJunk is not defined: +// +// Return (i,j,k) such that a[i:i+k] is equal to b[j:j+k], where +// +// alo <= i <= i+k <= ahi +// blo <= j <= j+k <= bhi +// +// and for all (i',j',k') meeting those conditions, +// +// k >= k' +// i <= i' +// and if i == i', j <= j' +// +// In other words, of all maximal matching blocks, return one that +// starts earliest in a, and of all those maximal matching blocks that +// start earliest in a, return the one that starts earliest in b. +// +// If IsJunk is defined, first the longest matching block is +// determined as above, but with the additional restriction that no +// junk element appears in the block. 
Then that block is extended as +// far as possible by matching (only) junk elements on both sides. So +// the resulting block never matches on junk except as identical junk +// happens to be adjacent to an "interesting" match. +// +// If no blocks match, return (alo, blo, 0). +func (m *SequenceMatcher) findLongestMatch(alo, ahi, blo, bhi int) Match { + // CAUTION: stripping common prefix or suffix would be incorrect. + // E.g., + // ab + // acab + // Longest matching block is "ab", but if common prefix is + // stripped, it's "a" (tied with "b"). UNIX(tm) diff does so + // strip, so ends up claiming that ab is changed to acab by + // inserting "ca" in the middle. That's minimal but unintuitive: + // "it's obvious" that someone inserted "ac" at the front. + // Windiff ends up at the same place as diff, but by pairing up + // the unique 'b's and then matching the first two 'a's. + besti, bestj, bestsize := alo, blo, 0 + + // find longest junk-free match + // during an iteration of the loop, j2len[j] = length of longest + // junk-free match ending with a[i-1] and b[j] + j2len := map[int]int{} + for i := alo; i != ahi; i++ { + // look at all instances of a[i] in b; note that because + // b2j has no junk keys, the loop is skipped if a[i] is junk + newj2len := map[int]int{} + for _, j := range m.b2j[m.a[i]] { + // a[i] matches b[j] + if j < blo { + continue + } + if j >= bhi { + break + } + k := j2len[j-1] + 1 + newj2len[j] = k + if k > bestsize { + besti, bestj, bestsize = i-k+1, j-k+1, k + } + } + j2len = newj2len + } + + // Extend the best by non-junk elements on each end. In particular, + // "popular" non-junk elements aren't in b2j, which greatly speeds + // the inner loop above, but also means "the best" match so far + // doesn't contain any junk *or* popular non-junk elements. + for besti > alo && bestj > blo && !m.isBJunk(m.b[bestj-1]) && + m.a[besti-1] == m.b[bestj-1] { + besti, bestj, bestsize = besti-1, bestj-1, bestsize+1 + } + for besti+bestsize < ahi && bestj+bestsize < bhi && + !m.isBJunk(m.b[bestj+bestsize]) && + m.a[besti+bestsize] == m.b[bestj+bestsize] { + bestsize++ + } + + // Now that we have a wholly interesting match (albeit possibly + // empty!), we may as well suck up the matching junk on each + // side of it too. Can't think of a good reason not to, and it + // saves post-processing the (possibly considerable) expense of + // figuring out what to do with it. In the case of an empty + // interesting match, this is clearly the right thing to do, + // because no other kind of match is possible in the regions. + for besti > alo && bestj > blo && m.isBJunk(m.b[bestj-1]) && + m.a[besti-1] == m.b[bestj-1] { + besti, bestj, bestsize = besti-1, bestj-1, bestsize+1 + } + for besti+bestsize < ahi && bestj+bestsize < bhi && + m.isBJunk(m.b[bestj+bestsize]) && + m.a[besti+bestsize] == m.b[bestj+bestsize] { + bestsize++ + } + + return Match{A: besti, B: bestj, Size: bestsize} +} + +// Return list of triples describing matching subsequences. +// +// Each triple is of the form (i, j, n), and means that +// a[i:i+n] == b[j:j+n]. The triples are monotonically increasing in +// i and in j. It's also guaranteed that if (i, j, n) and (i', j', n') are +// adjacent triples in the list, and the second is not the last triple in the +// list, then i+n != i' or j+n != j'. IOW, adjacent triples never describe +// adjacent equal blocks. +// +// The last triple is a dummy, (len(a), len(b), 0), and is the only +// triple with n==0. 
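Editor's note: to make the matching-block contract above concrete, here is a minimal, self-contained usage sketch. It is not part of this diff; it assumes the vendored package is importable as github.com/pmezard/go-difflib/difflib (the upstream project this file appears to come from) and uses only the exported API shown in this hunk.

package main

import (
	"fmt"

	"github.com/pmezard/go-difflib/difflib" // assumed import path
)

func main() {
	a := []string{"one\n", "two\n", "three\n"}
	b := []string{"one\n", "2\n", "three\n"}
	m := difflib.NewMatcher(a, b)
	// The last block is always the (len(a), len(b), 0) sentinel.
	for _, blk := range m.GetMatchingBlocks() {
		fmt.Printf("a[%d:%d] == b[%d:%d] (size %d)\n",
			blk.A, blk.A+blk.Size, blk.B, blk.B+blk.Size, blk.Size)
	}
}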
+func (m *SequenceMatcher) GetMatchingBlocks() []Match { + if m.matchingBlocks != nil { + return m.matchingBlocks + } + + var matchBlocks func(alo, ahi, blo, bhi int, matched []Match) []Match + matchBlocks = func(alo, ahi, blo, bhi int, matched []Match) []Match { + match := m.findLongestMatch(alo, ahi, blo, bhi) + i, j, k := match.A, match.B, match.Size + if match.Size > 0 { + if alo < i && blo < j { + matched = matchBlocks(alo, i, blo, j, matched) + } + matched = append(matched, match) + if i+k < ahi && j+k < bhi { + matched = matchBlocks(i+k, ahi, j+k, bhi, matched) + } + } + return matched + } + matched := matchBlocks(0, len(m.a), 0, len(m.b), nil) + + // It's possible that we have adjacent equal blocks in the + // matching_blocks list now. + nonAdjacent := []Match{} + i1, j1, k1 := 0, 0, 0 + for _, b := range matched { + // Is this block adjacent to i1, j1, k1? + i2, j2, k2 := b.A, b.B, b.Size + if i1+k1 == i2 && j1+k1 == j2 { + // Yes, so collapse them -- this just increases the length of + // the first block by the length of the second, and the first + // block so lengthened remains the block to compare against. + k1 += k2 + } else { + // Not adjacent. Remember the first block (k1==0 means it's + // the dummy we started with), and make the second block the + // new block to compare against. + if k1 > 0 { + nonAdjacent = append(nonAdjacent, Match{i1, j1, k1}) + } + i1, j1, k1 = i2, j2, k2 + } + } + if k1 > 0 { + nonAdjacent = append(nonAdjacent, Match{i1, j1, k1}) + } + + nonAdjacent = append(nonAdjacent, Match{len(m.a), len(m.b), 0}) + m.matchingBlocks = nonAdjacent + return m.matchingBlocks +} + +// Return list of 5-tuples describing how to turn a into b. +// +// Each tuple is of the form (tag, i1, i2, j1, j2). The first tuple +// has i1 == j1 == 0, and remaining tuples have i1 == the i2 from the +// tuple preceding it, and likewise for j1 == the previous j2. +// +// The tags are characters, with these meanings: +// +// 'r' (replace): a[i1:i2] should be replaced by b[j1:j2] +// +// 'd' (delete): a[i1:i2] should be deleted, j1==j2 in this case. +// +// 'i' (insert): b[j1:j2] should be inserted at a[i1:i1], i1==i2 in this case. +// +// 'e' (equal): a[i1:i2] == b[j1:j2] +func (m *SequenceMatcher) GetOpCodes() []OpCode { + if m.opCodes != nil { + return m.opCodes + } + i, j := 0, 0 + matching := m.GetMatchingBlocks() + opCodes := make([]OpCode, 0, len(matching)) + for _, m := range matching { + // invariant: we've pumped out correct diffs to change + // a[:i] into b[:j], and the next matching block is + // a[ai:ai+size] == b[bj:bj+size]. So we need to pump + // out a diff to change a[i:ai] into b[j:bj], pump out + // the matching block, and move (i,j) beyond the match + ai, bj, size := m.A, m.B, m.Size + tag := byte(0) + if i < ai && j < bj { + tag = 'r' + } else if i < ai { + tag = 'd' + } else if j < bj { + tag = 'i' + } + if tag > 0 { + opCodes = append(opCodes, OpCode{tag, i, ai, j, bj}) + } + i, j = ai+size, bj+size + // the list of matching blocks is terminated by a + // sentinel with size 0 + if size > 0 { + opCodes = append(opCodes, OpCode{'e', ai, i, bj, j}) + } + } + m.opCodes = opCodes + return m.opCodes +} + +// Isolate change clusters by eliminating ranges with no changes. +// +// Return a generator of groups with up to n lines of context. +// Each group is in the same format as returned by GetOpCodes(). 
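Editor's note: a hedged sketch of consuming the 5-tuples described above, using the same assumed import as the earlier example; printEditScript is a hypothetical helper, not part of the diff.

// printEditScript renders GetOpCodes' 5-tuples as a crude edit script.
func printEditScript(a, b []string) {
	m := difflib.NewMatcher(a, b)
	for _, op := range m.GetOpCodes() {
		switch op.Tag {
		case 'r': // replace
			fmt.Printf("replace a[%d:%d] with b[%d:%d]\n", op.I1, op.I2, op.J1, op.J2)
		case 'd': // delete
			fmt.Printf("delete a[%d:%d]\n", op.I1, op.I2)
		case 'i': // insert
			fmt.Printf("insert b[%d:%d] before a[%d]\n", op.J1, op.J2, op.I1)
		case 'e': // equal
			fmt.Printf("keep a[%d:%d]\n", op.I1, op.I2)
		}
	}
}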
+func (m *SequenceMatcher) GetGroupedOpCodes(n int) [][]OpCode { + if n < 0 { + n = 3 + } + codes := m.GetOpCodes() + if len(codes) == 0 { + codes = []OpCode{{'e', 0, 1, 0, 1}} + } + // Fixup leading and trailing groups if they show no changes. + if codes[0].Tag == 'e' { + c := codes[0] + i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2 + codes[0] = OpCode{c.Tag, max(i1, i2-n), i2, max(j1, j2-n), j2} + } + if codes[len(codes)-1].Tag == 'e' { + c := codes[len(codes)-1] + i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2 + codes[len(codes)-1] = OpCode{c.Tag, i1, min(i2, i1+n), j1, min(j2, j1+n)} + } + nn := n + n + groups := [][]OpCode{} + group := []OpCode{} + for _, c := range codes { + i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2 + // End the current group and start a new one whenever + // there is a large range with no changes. + if c.Tag == 'e' && i2-i1 > nn { + group = append(group, OpCode{ + c.Tag, i1, min(i2, i1+n), + j1, min(j2, j1+n), + }) + groups = append(groups, group) + group = []OpCode{} + i1, j1 = max(i1, i2-n), max(j1, j2-n) + } + group = append(group, OpCode{c.Tag, i1, i2, j1, j2}) + } + if len(group) > 0 && !(len(group) == 1 && group[0].Tag == 'e') { + groups = append(groups, group) + } + return groups +} + +// Return a measure of the sequences' similarity (float in [0,1]). +// +// Where T is the total number of elements in both sequences, and +// M is the number of matches, this is 2.0*M / T. +// Note that this is 1 if the sequences are identical, and 0 if +// they have nothing in common. +// +// .Ratio() is expensive to compute if you haven't already computed +// .GetMatchingBlocks() or .GetOpCodes(), in which case you may +// want to try .QuickRatio() or .RealQuickRatio() first to get an +// upper bound. +func (m *SequenceMatcher) Ratio() float64 { + matches := 0 + for _, m := range m.GetMatchingBlocks() { + matches += m.Size + } + return calculateRatio(matches, len(m.a)+len(m.b)) +} + +// Return an upper bound on ratio() relatively quickly. +// +// This isn't defined beyond that it is an upper bound on .Ratio(), and +// is faster to compute. +func (m *SequenceMatcher) QuickRatio() float64 { + // viewing a and b as multisets, set matches to the cardinality + // of their intersection; this counts the number of matches + // without regard to order, so is clearly an upper bound + if m.fullBCount == nil { + m.fullBCount = map[string]int{} + for _, s := range m.b { + m.fullBCount[s]++ + } + } + + // avail[x] is the number of times x appears in 'b' less the + // number of times we've seen it in 'a' so far ... kinda + avail := map[string]int{} + matches := 0 + for _, s := range m.a { + n, ok := avail[s] + if !ok { + n = m.fullBCount[s] + } + avail[s] = n - 1 + if n > 0 { + matches++ + } + } + return calculateRatio(matches, len(m.a)+len(m.b)) +} + +// Return an upper bound on ratio() very quickly. +// +// This isn't defined beyond that it is an upper bound on .Ratio(), and +// is faster to compute than either .Ratio() or .QuickRatio().
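Editor's note: the upper-bound guarantees documented above (Ratio <= QuickRatio <= RealQuickRatio) suggest a cheap-to-expensive cascade when filtering many candidate sequences. A sketch, assuming a hypothetical cutoff threshold such as 0.6:

// isCloseMatch short-circuits: the cheap bounds run first, and the
// expensive Ratio is only computed once both bounds already pass.
func isCloseMatch(m *difflib.SequenceMatcher, cutoff float64) bool {
	return m.RealQuickRatio() >= cutoff &&
		m.QuickRatio() >= cutoff &&
		m.Ratio() >= cutoff
}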
+func (m *SequenceMatcher) RealQuickRatio() float64 { + la, lb := len(m.a), len(m.b) + return calculateRatio(min(la, lb), la+lb) +} + +// Convert range to the "ed" format +func formatRangeUnified(start, stop int) string { + // Per the diff spec at http://www.unix.org/single_unix_specification/ + beginning := start + 1 // lines start numbering with one + length := stop - start + if length == 1 { + return fmt.Sprintf("%d", beginning) + } + if length == 0 { + beginning-- // empty ranges begin at line just before the range + } + return fmt.Sprintf("%d,%d", beginning, length) +} + +// Unified diff parameters +type UnifiedDiff struct { + A []string // First sequence lines + FromFile string // First file name + FromDate string // First file time + B []string // Second sequence lines + ToFile string // Second file name + ToDate string // Second file time + Eol string // Headers end of line, defaults to LF + Context int // Number of context lines +} + +// Compare two sequences of lines; generate the delta as a unified diff. +// +// Unified diffs are a compact way of showing line changes and a few +// lines of context. The number of context lines is set by 'n' which +// defaults to three. +// +// By default, the diff control lines (those with ---, +++, or @@) are +// created with a trailing newline. This is helpful so that inputs +// created from file.readlines() result in diffs that are suitable for +// file.writelines() since both the inputs and outputs have trailing +// newlines. +// +// For inputs that do not have trailing newlines, set the lineterm +// argument to "" so that the output will be uniformly newline free. +// +// The unidiff format normally has a header for filenames and modification +// times. Any or all of these may be specified using strings for +// 'fromfile', 'tofile', 'fromfiledate', and 'tofiledate'. +// The modification times are normally expressed in the ISO 8601 format. 
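Editor's note: a minimal sketch of filling in the UnifiedDiff parameters and rendering them via GetUnifiedDiffString and SplitLines (both declared just below in this hunk); the file names and contents are made up for illustration.

diff := difflib.UnifiedDiff{
	A:        difflib.SplitLines("one\ntwo\nthree\n"),
	B:        difflib.SplitLines("one\n2\nthree\n"),
	FromFile: "a.txt", // hypothetical names
	ToFile:   "b.txt",
	Context:  3,
}
text, err := difflib.GetUnifiedDiffString(diff)
if err != nil {
	panic(err) // illustration only
}
fmt.Print(text)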
+func WriteUnifiedDiff(writer io.Writer, diff UnifiedDiff) error { + buf := bufio.NewWriter(writer) + defer buf.Flush() + wf := func(format string, args ...interface{}) error { + _, err := buf.WriteString(fmt.Sprintf(format, args...)) + return err + } + ws := func(s string) error { + _, err := buf.WriteString(s) + return err + } + + if len(diff.Eol) == 0 { + diff.Eol = "\n" + } + + started := false + m := NewMatcher(diff.A, diff.B) + for _, g := range m.GetGroupedOpCodes(diff.Context) { + if !started { + started = true + fromDate := "" + if len(diff.FromDate) > 0 { + fromDate = "\t" + diff.FromDate + } + toDate := "" + if len(diff.ToDate) > 0 { + toDate = "\t" + diff.ToDate + } + if diff.FromFile != "" || diff.ToFile != "" { + err := wf("--- %s%s%s", diff.FromFile, fromDate, diff.Eol) + if err != nil { + return err + } + err = wf("+++ %s%s%s", diff.ToFile, toDate, diff.Eol) + if err != nil { + return err + } + } + } + first, last := g[0], g[len(g)-1] + range1 := formatRangeUnified(first.I1, last.I2) + range2 := formatRangeUnified(first.J1, last.J2) + if err := wf("@@ -%s +%s @@%s", range1, range2, diff.Eol); err != nil { + return err + } + for _, c := range g { + i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2 + if c.Tag == 'e' { + for _, line := range diff.A[i1:i2] { + if err := ws(" " + line); err != nil { + return err + } + } + continue + } + if c.Tag == 'r' || c.Tag == 'd' { + for _, line := range diff.A[i1:i2] { + if err := ws("-" + line); err != nil { + return err + } + } + } + if c.Tag == 'r' || c.Tag == 'i' { + for _, line := range diff.B[j1:j2] { + if err := ws("+" + line); err != nil { + return err + } + } + } + } + } + return nil +} + +// Like WriteUnifiedDiff but returns the diff as a string. +func GetUnifiedDiffString(diff UnifiedDiff) (string, error) { + w := &bytes.Buffer{} + err := WriteUnifiedDiff(w, diff) + return w.String(), err +} + +// Split a string on "\n" while preserving the newlines. The output can be used +// as input for UnifiedDiff and ContextDiff structures. +func SplitLines(s string) []string { + lines := strings.SplitAfter(s, "\n") + lines[len(lines)-1] += "\n" + return lines +} diff --git a/vendor/github.com/prometheus/client_golang/prometheus/labels.go b/vendor/github.com/prometheus/client_golang/prometheus/labels.go new file mode 100644 index 0000000000..63ff8683ce --- /dev/null +++ b/vendor/github.com/prometheus/client_golang/prometheus/labels.go @@ -0,0 +1,160 @@ +// Copyright 2018 The Prometheus Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package prometheus + +import ( + "errors" + "fmt" + "strings" + "unicode/utf8" + + "github.com/prometheus/common/model" +) + +// Labels represents a collection of label name -> value mappings. This type is +// commonly used with the With(Labels) and GetMetricWith(Labels) methods of +// metric vector Collectors, e.g.: +// +// myVec.With(Labels{"code": "404", "method": "GET"}).Add(42) +// +// The other use-case is the specification of constant label pairs in Opts or to +// create a Desc.
+type Labels map[string]string + +// ConstrainedLabel represents a label name and its constraint function, +// used to normalize label values. This type is commonly used when constructing +// metric vector Collectors. +type ConstrainedLabel struct { + Name string + Constraint func(string) string +} + +func (cl ConstrainedLabel) Constrain(v string) string { + if cl.Constraint == nil { + return v + } + return cl.Constraint(v) +} + +// ConstrainableLabels is an interface that allows creating labels that can +// be optionally constrained. +// +// prometheus.V2().NewCounterVec(CounterVecOpts{ +// CounterOpts: {...}, // Usual CounterOpts fields +// VariableLabels: []ConstrainedLabels{ +// {Name: "A"}, +// {Name: "B", Constraint: func(v string) string { ... }}, +// }, +// }) +type ConstrainableLabels interface { + constrainedLabels() ConstrainedLabels + labelNames() []string +} + +// ConstrainedLabels represents a collection of label name -> constraint function +// mappings used to normalize label values. This type is commonly used when +// constructing metric vector Collectors. +type ConstrainedLabels []ConstrainedLabel + +func (cls ConstrainedLabels) constrainedLabels() ConstrainedLabels { + return cls +} + +func (cls ConstrainedLabels) labelNames() []string { + names := make([]string, len(cls)) + for i, label := range cls { + names[i] = label.Name + } + return names +} + +// UnconstrainedLabels represents a collection of labels without any constraints on +// their values. Thus, it is simply a collection of label names. +// +// UnconstrainedLabels([]string{ "A", "B" }) +// +// is equivalent to +// +// ConstrainedLabels { +// { Name: "A" }, +// { Name: "B" }, +// } +type UnconstrainedLabels []string + +func (uls UnconstrainedLabels) constrainedLabels() ConstrainedLabels { + constrainedLabels := make([]ConstrainedLabel, len(uls)) + for i, l := range uls { + constrainedLabels[i] = ConstrainedLabel{Name: l} + } + return constrainedLabels +} + +func (uls UnconstrainedLabels) labelNames() []string { + return uls +} + +// reservedLabelPrefix is a prefix which is not legal in user-supplied +// label names.
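Editor's note: a small sketch of the constrained-label machinery above, not part of the diff. strings.ToLower already has the func(string) string shape that Constraint expects, so "GET" and "get" collapse into one label value.

method := prometheus.ConstrainedLabel{
	Name:       "method",
	Constraint: strings.ToLower, // normalize values before they become label values
}
fmt.Println(method.Constrain("GET")) // prints "get"
fmt.Println(method.Constrain("get")) // prints "get"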
+const reservedLabelPrefix = "__" + +var errInconsistentCardinality = errors.New("inconsistent label cardinality") + +func makeInconsistentCardinalityError(fqName string, labels, labelValues []string) error { + return fmt.Errorf( + "%w: %q has %d variable labels named %q but %d values %q were provided", + errInconsistentCardinality, fqName, + len(labels), labels, + len(labelValues), labelValues, + ) +} + +func validateValuesInLabels(labels Labels, expectedNumberOfValues int) error { + if len(labels) != expectedNumberOfValues { + return fmt.Errorf( + "%w: expected %d label values but got %d in %#v", + errInconsistentCardinality, expectedNumberOfValues, + len(labels), labels, + ) + } + + for name, val := range labels { + if !utf8.ValidString(val) { + return fmt.Errorf("label %s: value %q is not valid UTF-8", name, val) + } + } + + return nil +} + +func validateLabelValues(vals []string, expectedNumberOfValues int) error { + if len(vals) != expectedNumberOfValues { + return fmt.Errorf( + "%w: expected %d label values but got %d in %#v", + errInconsistentCardinality, expectedNumberOfValues, + len(vals), vals, + ) + } + + for _, val := range vals { + if !utf8.ValidString(val) { + return fmt.Errorf("label value %q is not valid UTF-8", val) + } + } + + return nil +} + +func checkLabelName(l string) bool { + return model.LabelName(l).IsValid() && !strings.HasPrefix(l, reservedLabelPrefix) +} diff --git a/vendor/github.com/prometheus/client_golang/prometheus/metric.go b/vendor/github.com/prometheus/client_golang/prometheus/metric.go new file mode 100644 index 0000000000..07bbc9d768 --- /dev/null +++ b/vendor/github.com/prometheus/client_golang/prometheus/metric.go @@ -0,0 +1,254 @@ +// Copyright 2014 The Prometheus Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package prometheus + +import ( + "errors" + "math" + "sort" + "strings" + "time" + + dto "github.com/prometheus/client_model/go" + "github.com/prometheus/common/model" + "google.golang.org/protobuf/proto" +) + +var separatorByteSlice = []byte{model.SeparatorByte} // For convenient use with xxhash. + +// A Metric models a single sample value with its meta data being exported to +// Prometheus. Implementations of Metric in this package are Gauge, Counter, +// Histogram, Summary, and Untyped. +type Metric interface { + // Desc returns the descriptor for the Metric. This method idempotently + // returns the same descriptor throughout the lifetime of the + // Metric. The returned descriptor is immutable by contract. A Metric + // unable to describe itself must return an invalid descriptor (created + // with NewInvalidDesc). + Desc() *Desc + // Write encodes the Metric into a "Metric" Protocol Buffer data + // transmission object. + // + // Metric implementations must observe concurrency safety as reads of + // this metric may occur at any time, and any blocking occurs at the + // expense of total performance of rendering all registered + // metrics. Ideally, Metric implementations should support concurrent + // readers. 
+ // + // While populating dto.Metric, it is the responsibility of the + // implementation to ensure validity of the Metric protobuf (like valid + // UTF-8 strings or syntactically valid metric and label names). It is + // recommended to sort labels lexicographically. Callers of Write should + // still make sure of sorting if they depend on it. + Write(*dto.Metric) error + // TODO(beorn7): The original rationale of passing in a pre-allocated + // dto.Metric protobuf to save allocations has disappeared. The + // signature of this method should be changed to "Write() (*dto.Metric, + // error)". +} + +// Opts bundles the options for creating most Metric types. Each metric +// implementation XXX has its own XXXOpts type, but in most cases, it is just +// an alias of this type (which might change when the requirement arises.) +// +// It is mandatory to set Name to a non-empty string. All other fields are +// optional and can safely be left at their zero value, although it is strongly +// encouraged to set a Help string. +type Opts struct { + // Namespace, Subsystem, and Name are components of the fully-qualified + // name of the Metric (created by joining these components with + // "_"). Only Name is mandatory, the others merely help structuring the + // name. Note that the fully-qualified name of the metric must be a + // valid Prometheus metric name. + Namespace string + Subsystem string + Name string + + // Help provides information about this metric. + // + // Metrics with the same fully-qualified name must have the same Help + // string. + Help string + + // ConstLabels are used to attach fixed labels to this metric. Metrics + // with the same fully-qualified name must have the same label names in + // their ConstLabels. + // + // ConstLabels are only used rarely. In particular, do not use them to + // attach the same labels to all your metrics. Those use cases are + // better covered by target labels set by the scraping Prometheus + // server, or by one specific metric (e.g. a build_info or a + // machine_role metric). See also + // https://prometheus.io/docs/instrumenting/writing_exporters/#target-labels-not-static-scraped-labels + ConstLabels Labels +} + +// BuildFQName joins the given three name components by "_". Empty name +// components are ignored. If the name parameter itself is empty, an empty +// string is returned, no matter what. Metric implementations included in this +// library use this function internally to generate the fully-qualified metric +// name from the name component in their Opts. Users of the library will only +// need this function if they implement their own Metric or instantiate a Desc +// (with NewDesc) directly. +func BuildFQName(namespace, subsystem, name string) string { + if name == "" { + return "" + } + switch { + case namespace != "" && subsystem != "": + return strings.Join([]string{namespace, subsystem, name}, "_") + case namespace != "": + return strings.Join([]string{namespace, name}, "_") + case subsystem != "": + return strings.Join([]string{subsystem, name}, "_") + } + return name +} + +type invalidMetric struct { + desc *Desc + err error +} + +// NewInvalidMetric returns a metric whose Write method always returns the +// provided error. It is useful if a Collector finds itself unable to collect +// a metric and wishes to report an error to the registry. 
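Editor's note: the joining rules of BuildFQName, illustrated with values read directly from the switch above (the metric names are made up):

fmt.Println(prometheus.BuildFQName("http", "server", "requests_total")) // "http_server_requests_total"
fmt.Println(prometheus.BuildFQName("", "server", "requests_total"))     // "server_requests_total"
fmt.Println(prometheus.BuildFQName("http", "", "requests_total"))       // "http_requests_total"
fmt.Println(prometheus.BuildFQName("http", "server", ""))               // "" (an empty name always yields "")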
+func NewInvalidMetric(desc *Desc, err error) Metric { + return &invalidMetric{desc, err} +} + +func (m *invalidMetric) Desc() *Desc { return m.desc } + +func (m *invalidMetric) Write(*dto.Metric) error { return m.err } + +type timestampedMetric struct { + Metric + t time.Time +} + +func (m timestampedMetric) Write(pb *dto.Metric) error { + e := m.Metric.Write(pb) + pb.TimestampMs = proto.Int64(m.t.Unix()*1000 + int64(m.t.Nanosecond()/1000000)) + return e +} + +// NewMetricWithTimestamp returns a new Metric wrapping the provided Metric in a +// way that it has an explicit timestamp set to the provided Time. This is only +// useful in rare cases as the timestamp of a Prometheus metric should usually +// be set by the Prometheus server during scraping. Exceptions include mirroring +// metrics with given timestamps from other metric +// sources. +// +// NewMetricWithTimestamp works best with MustNewConstMetric, +// MustNewConstHistogram, and MustNewConstSummary, see example. +// +// Currently, the exposition formats used by Prometheus are limited to +// millisecond resolution. Thus, the provided time will be rounded down to the +// next full millisecond value. +func NewMetricWithTimestamp(t time.Time, m Metric) Metric { + return timestampedMetric{Metric: m, t: t} +} + +type withExemplarsMetric struct { + Metric + + exemplars []*dto.Exemplar +} + +func (m *withExemplarsMetric) Write(pb *dto.Metric) error { + if err := m.Metric.Write(pb); err != nil { + return err + } + + switch { + case pb.Counter != nil: + pb.Counter.Exemplar = m.exemplars[len(m.exemplars)-1] + case pb.Histogram != nil: + for _, e := range m.exemplars { + // pb.Histogram.Bucket are sorted by UpperBound. + i := sort.Search(len(pb.Histogram.Bucket), func(i int) bool { + return pb.Histogram.Bucket[i].GetUpperBound() >= e.GetValue() + }) + if i < len(pb.Histogram.Bucket) { + pb.Histogram.Bucket[i].Exemplar = e + } else { + // The +Inf bucket should be explicitly added if there is an exemplar for it, similar to non-const histogram logic in https://github.com/prometheus/client_golang/blob/main/prometheus/histogram.go#L357-L365. + b := &dto.Bucket{ + CumulativeCount: proto.Uint64(pb.Histogram.GetSampleCount()), + UpperBound: proto.Float64(math.Inf(1)), + Exemplar: e, + } + pb.Histogram.Bucket = append(pb.Histogram.Bucket, b) + } + } + default: + // TODO(bwplotka): Implement Gauge? + return errors.New("cannot inject exemplar into Gauge, Summary or Untyped") + } + + return nil +} + +// Exemplar is an easier-to-use, user-facing representation of *dto.Exemplar. +type Exemplar struct { + Value float64 + Labels Labels + // Optional. + // Default value (time.Time{}) indicates it's empty, which should be + // understood as time.Now() at the moment the metric is created. + Timestamp time.Time +} + +// NewMetricWithExemplars returns a new Metric wrapping the provided Metric with given +// exemplars. Exemplars are validated. +// +// Only the last applicable exemplar from the list is injected. +// For example, for a Counter this means the last exemplar is injected. +// For a Histogram, it means the last applicable exemplar for each bucket is injected. +// +// NewMetricWithExemplars works best with MustNewConstMetric and +// MustNewConstHistogram, see example.
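Editor's note: a hedged sketch of the usage the comment above refers to, pairing NewMetricWithExemplars with MustNewConstMetric; the metric name and trace_id value are made up for illustration.

desc := prometheus.NewDesc("demo_requests_total", "Editorial example counter.", nil, nil)
m := prometheus.MustNewConstMetric(desc, prometheus.CounterValue, 42)
em, err := prometheus.NewMetricWithExemplars(m, prometheus.Exemplar{
	Value:  42,
	Labels: prometheus.Labels{"trace_id": "abc123"},
	// Timestamp left at its zero value: treated as time.Now(), per the doc above.
})
if err != nil {
	panic(err) // illustration only
}
_ = em // expose em instead of m, e.g. from a custom Collector's Collect method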
+func NewMetricWithExemplars(m Metric, exemplars ...Exemplar) (Metric, error) { + if len(exemplars) == 0 { + return nil, errors.New("no exemplar was passed for NewMetricWithExemplars") + } + + var ( + now = time.Now() + exs = make([]*dto.Exemplar, len(exemplars)) + err error + ) + for i, e := range exemplars { + ts := e.Timestamp + if ts == (time.Time{}) { + ts = now + } + exs[i], err = newExemplar(e.Value, ts, e.Labels) + if err != nil { + return nil, err + } + } + + return &withExemplarsMetric{Metric: m, exemplars: exs}, nil +} + +// MustNewMetricWithExemplars is a version of NewMetricWithExemplars that panics where +// NewMetricWithExemplars would have returned an error. +func MustNewMetricWithExemplars(m Metric, exemplars ...Exemplar) Metric { + ret, err := NewMetricWithExemplars(m, exemplars...) + if err != nil { + panic(err) + } + return ret +} diff --git a/vendor/github.com/prometheus/client_golang/prometheus/promhttp/instrument_server.go b/vendor/github.com/prometheus/client_golang/prometheus/promhttp/instrument_server.go new file mode 100644 index 0000000000..3793036ad0 --- /dev/null +++ b/vendor/github.com/prometheus/client_golang/prometheus/promhttp/instrument_server.go @@ -0,0 +1,579 @@ +// Copyright 2017 The Prometheus Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package promhttp + +import ( + "errors" + "net/http" + "strconv" + "strings" + "time" + + dto "github.com/prometheus/client_model/go" + + "github.com/prometheus/client_golang/prometheus" +) + +// magicString is used for the hacky label test in checkLabels. Remove once fixed. +const magicString = "zZgWfBxLqvG8kc8IMv3POi2Bb0tZI3vAnBx+gBaFi9FyPzB/CzKUer1yufDa" + +// observeWithExemplar is a wrapper for [prometheus.ExemplarAdder.ExemplarObserver], +// which falls back to [prometheus.Observer.Observe] if no labels are provided. +func observeWithExemplar(obs prometheus.Observer, val float64, labels map[string]string) { + if labels == nil { + obs.Observe(val) + return + } + obs.(prometheus.ExemplarObserver).ObserveWithExemplar(val, labels) +} + +// addWithExemplar is a wrapper for [prometheus.ExemplarAdder.AddWithExemplar], +// which falls back to [prometheus.Counter.Add] if no labels are provided. +func addWithExemplar(obs prometheus.Counter, val float64, labels map[string]string) { + if labels == nil { + obs.Add(val) + return + } + obs.(prometheus.ExemplarAdder).AddWithExemplar(val, labels) +} + +// InstrumentHandlerInFlight is a middleware that wraps the provided +// http.Handler. It sets the provided prometheus.Gauge to the number of +// requests currently handled by the wrapped http.Handler. +// +// See the example for InstrumentHandlerDuration for example usage. 
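Editor's note: a minimal sketch of wrapping a handler with InstrumentHandlerInFlight; apiHandler is a hypothetical http.Handler, and the metric name is assumed.

inFlight := prometheus.NewGauge(prometheus.GaugeOpts{
	Name: "http_in_flight_requests", // assumed metric name
	Help: "Current number of requests being served.",
})
prometheus.MustRegister(inFlight)
http.Handle("/api", promhttp.InstrumentHandlerInFlight(inFlight, apiHandler))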
+func InstrumentHandlerInFlight(g prometheus.Gauge, next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + g.Inc() + defer g.Dec() + next.ServeHTTP(w, r) + }) +} + +// InstrumentHandlerDuration is a middleware that wraps the provided +// http.Handler to observe the request duration with the provided ObserverVec. +// The ObserverVec must have valid metric and label names and must have zero, +// one, or two non-const non-curried labels. For those, the only allowed label +// names are "code" and "method". The function panics otherwise. For the "method" +// label a predefined default label value set is used to filter given values. +// Values besides predefined values will count as `unknown` method. +// `WithExtraMethods` can be used to add more methods to the set. The Observe +// method of the Observer in the ObserverVec is called with the request duration +// in seconds. Partitioning happens by HTTP status code and/or HTTP method if +// the respective instance label names are present in the ObserverVec. For +// unpartitioned observations, use an ObserverVec with zero labels. Note that +// partitioning of Histograms is expensive and should be used judiciously. +// +// If the wrapped Handler does not set a status code, a status code of 200 is assumed. +// +// If the wrapped Handler panics, no values are reported. +// +// Note that this method is only guaranteed to never observe negative durations +// if used with Go1.9+. +func InstrumentHandlerDuration(obs prometheus.ObserverVec, next http.Handler, opts ...Option) http.HandlerFunc { + hOpts := defaultOptions() + for _, o := range opts { + o.apply(hOpts) + } + + // Curry the observer with dynamic labels before checking the remaining labels. + code, method := checkLabels(obs.MustCurryWith(hOpts.emptyDynamicLabels())) + + if code { + return func(w http.ResponseWriter, r *http.Request) { + now := time.Now() + d := newDelegator(w, nil) + next.ServeHTTP(d, r) + + l := labels(code, method, r.Method, d.Status(), hOpts.extraMethods...) + for label, resolve := range hOpts.extraLabelsFromCtx { + l[label] = resolve(r.Context()) + } + observeWithExemplar(obs.With(l), time.Since(now).Seconds(), hOpts.getExemplarFn(r.Context())) + } + } + + return func(w http.ResponseWriter, r *http.Request) { + now := time.Now() + next.ServeHTTP(w, r) + l := labels(code, method, r.Method, 0, hOpts.extraMethods...) + for label, resolve := range hOpts.extraLabelsFromCtx { + l[label] = resolve(r.Context()) + } + observeWithExemplar(obs.With(l), time.Since(now).Seconds(), hOpts.getExemplarFn(r.Context())) + } +} + +// InstrumentHandlerCounter is a middleware that wraps the provided http.Handler +// to observe the request result with the provided CounterVec. The CounterVec +// must have valid metric and label names and must have zero, one, or two +// non-const non-curried labels. For those, the only allowed label names are +// "code" and "method". The function panics otherwise. For the "method" +// label a predefined default label value set is used to filter given values. +// Values besides predefined values will count as `unknown` method. +// `WithExtraMethods` can be used to add more methods to the set. Partitioning of the +// CounterVec happens by HTTP status code and/or HTTP method if the respective +// instance label names are present in the CounterVec. For unpartitioned +// counting, use a CounterVec with zero labels. +// +// If the wrapped Handler does not set a status code, a status code of 200 is assumed. 
+// +// If the wrapped Handler panics, the Counter is not incremented. +// +// See the example for InstrumentHandlerDuration for example usage. +func InstrumentHandlerCounter(counter *prometheus.CounterVec, next http.Handler, opts ...Option) http.HandlerFunc { + hOpts := defaultOptions() + for _, o := range opts { + o.apply(hOpts) + } + + // Curry the counter with dynamic labels before checking the remaining labels. + code, method := checkLabels(counter.MustCurryWith(hOpts.emptyDynamicLabels())) + + if code { + return func(w http.ResponseWriter, r *http.Request) { + d := newDelegator(w, nil) + next.ServeHTTP(d, r) + + l := labels(code, method, r.Method, d.Status(), hOpts.extraMethods...) + for label, resolve := range hOpts.extraLabelsFromCtx { + l[label] = resolve(r.Context()) + } + addWithExemplar(counter.With(l), 1, hOpts.getExemplarFn(r.Context())) + } + } + + return func(w http.ResponseWriter, r *http.Request) { + next.ServeHTTP(w, r) + + l := labels(code, method, r.Method, 0, hOpts.extraMethods...) + for label, resolve := range hOpts.extraLabelsFromCtx { + l[label] = resolve(r.Context()) + } + addWithExemplar(counter.With(l), 1, hOpts.getExemplarFn(r.Context())) + } +} + +// InstrumentHandlerTimeToWriteHeader is a middleware that wraps the provided +// http.Handler to observe with the provided ObserverVec the request duration +// until the response headers are written. The ObserverVec must have valid +// metric and label names and must have zero, one, or two non-const non-curried +// labels. For those, the only allowed label names are "code" and "method". The +// function panics otherwise. For the "method" label a predefined default label +// value set is used to filter given values. Values besides predefined values +// will count as `unknown` method. `WithExtraMethods` can be used to add more +// methods to the set. The Observe method of the Observer in the +// ObserverVec is called with the request duration in seconds. Partitioning +// happens by HTTP status code and/or HTTP method if the respective instance +// label names are present in the ObserverVec. For unpartitioned observations, +// use an ObserverVec with zero labels. Note that partitioning of Histograms is +// expensive and should be used judiciously. +// +// If the wrapped Handler panics before calling WriteHeader, no value is +// reported. +// +// Note that this method is only guaranteed to never observe negative durations +// if used with Go1.9+. +// +// See the example for InstrumentHandlerDuration for example usage. +func InstrumentHandlerTimeToWriteHeader(obs prometheus.ObserverVec, next http.Handler, opts ...Option) http.HandlerFunc { + hOpts := defaultOptions() + for _, o := range opts { + o.apply(hOpts) + } + + // Curry the observer with dynamic labels before checking the remaining labels. + code, method := checkLabels(obs.MustCurryWith(hOpts.emptyDynamicLabels())) + + return func(w http.ResponseWriter, r *http.Request) { + now := time.Now() + d := newDelegator(w, func(status int) { + l := labels(code, method, r.Method, status, hOpts.extraMethods...) + for label, resolve := range hOpts.extraLabelsFromCtx { + l[label] = resolve(r.Context()) + } + observeWithExemplar(obs.With(l), time.Since(now).Seconds(), hOpts.getExemplarFn(r.Context())) + }) + next.ServeHTTP(d, r) + } +} + +// InstrumentHandlerRequestSize is a middleware that wraps the provided +// http.Handler to observe the request size with the provided ObserverVec.
The +// ObserverVec must have valid metric and label names and must have zero, one, +// or two non-const non-curried labels. For those, the only allowed label names +// are "code" and "method". The function panics otherwise. For the "method" +// label a predefined default label value set is used to filter given values. +// Values besides predefined values will count as `unknown` method. +// `WithExtraMethods` can be used to add more methods to the set. The Observe +// method of the Observer in the ObserverVec is called with the request size in +// bytes. Partitioning happens by HTTP status code and/or HTTP method if the +// respective instance label names are present in the ObserverVec. For +// unpartitioned observations, use an ObserverVec with zero labels. Note that +// partitioning of Histograms is expensive and should be used judiciously. +// +// If the wrapped Handler does not set a status code, a status code of 200 is assumed. +// +// If the wrapped Handler panics, no values are reported. +// +// See the example for InstrumentHandlerDuration for example usage. +func InstrumentHandlerRequestSize(obs prometheus.ObserverVec, next http.Handler, opts ...Option) http.HandlerFunc { + hOpts := defaultOptions() + for _, o := range opts { + o.apply(hOpts) + } + + // Curry the observer with dynamic labels before checking the remaining labels. + code, method := checkLabels(obs.MustCurryWith(hOpts.emptyDynamicLabels())) + + if code { + return func(w http.ResponseWriter, r *http.Request) { + d := newDelegator(w, nil) + next.ServeHTTP(d, r) + size := computeApproximateRequestSize(r) + + l := labels(code, method, r.Method, d.Status(), hOpts.extraMethods...) + for label, resolve := range hOpts.extraLabelsFromCtx { + l[label] = resolve(r.Context()) + } + observeWithExemplar(obs.With(l), float64(size), hOpts.getExemplarFn(r.Context())) + } + } + + return func(w http.ResponseWriter, r *http.Request) { + next.ServeHTTP(w, r) + size := computeApproximateRequestSize(r) + + l := labels(code, method, r.Method, 0, hOpts.extraMethods...) + for label, resolve := range hOpts.extraLabelsFromCtx { + l[label] = resolve(r.Context()) + } + observeWithExemplar(obs.With(l), float64(size), hOpts.getExemplarFn(r.Context())) + } +} + +// InstrumentHandlerResponseSize is a middleware that wraps the provided +// http.Handler to observe the response size with the provided ObserverVec. The +// ObserverVec must have valid metric and label names and must have zero, one, +// or two non-const non-curried labels. For those, the only allowed label names +// are "code" and "method". The function panics otherwise. For the "method" +// label a predefined default label value set is used to filter given values. +// Values besides predefined values will count as `unknown` method. +// `WithExtraMethods` can be used to add more methods to the set. The Observe +// method of the Observer in the ObserverVec is called with the response size in +// bytes. Partitioning happens by HTTP status code and/or HTTP method if the +// respective instance label names are present in the ObserverVec. For +// unpartitioned observations, use an ObserverVec with zero labels. Note that +// partitioning of Histograms is expensive and should be used judiciously. +// +// If the wrapped Handler does not set a status code, a status code of 200 is assumed. +// +// If the wrapped Handler panics, no values are reported. +// +// See the example for InstrumentHandlerDuration for example usage. 
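Editor's note: these middlewares compose by nesting, since each returns an http.Handler (or http.HandlerFunc, which satisfies it). A sketch of the conventional chain, with assumed metric names and a hypothetical apiHandler:

func setupInstrumentedHandler(apiHandler http.Handler) http.Handler {
	counter := prometheus.NewCounterVec(
		prometheus.CounterOpts{Name: "api_requests_total", Help: "Requests by code and method."},
		[]string{"code", "method"},
	)
	duration := prometheus.NewHistogramVec(
		prometheus.HistogramOpts{Name: "request_duration_seconds", Help: "Request latency."},
		[]string{"code", "method"},
	)
	responseSize := prometheus.NewHistogramVec(
		prometheus.HistogramOpts{Name: "response_size_bytes", Help: "Response sizes."},
		nil, // zero labels: unpartitioned observations
	)
	prometheus.MustRegister(counter, duration, responseSize)

	// Innermost wrapper runs closest to the real handler.
	return promhttp.InstrumentHandlerCounter(counter,
		promhttp.InstrumentHandlerDuration(duration,
			promhttp.InstrumentHandlerResponseSize(responseSize, apiHandler),
		),
	)
}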
+func InstrumentHandlerResponseSize(obs prometheus.ObserverVec, next http.Handler, opts ...Option) http.Handler { + hOpts := defaultOptions() + for _, o := range opts { + o.apply(hOpts) + } + + // Curry the observer with dynamic labels before checking the remaining labels. + code, method := checkLabels(obs.MustCurryWith(hOpts.emptyDynamicLabels())) + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + d := newDelegator(w, nil) + next.ServeHTTP(d, r) + + l := labels(code, method, r.Method, d.Status(), hOpts.extraMethods...) + for label, resolve := range hOpts.extraLabelsFromCtx { + l[label] = resolve(r.Context()) + } + observeWithExemplar(obs.With(l), float64(d.Written()), hOpts.getExemplarFn(r.Context())) + }) +} + +// checkLabels returns whether the provided Collector has a non-const, +// non-curried label named "code" and/or "method". It panics if the provided +// Collector does not have a Desc or has more than one Desc or its Desc is +// invalid. It also panics if the Collector has any non-const, non-curried +// labels that are not named "code" or "method". +func checkLabels(c prometheus.Collector) (code, method bool) { + // TODO(beorn7): Remove this hacky way to check for instance labels + // once Descriptors can have their dimensionality queried. + var ( + desc *prometheus.Desc + m prometheus.Metric + pm dto.Metric + lvs []string + ) + + // Get the Desc from the Collector. + descc := make(chan *prometheus.Desc, 1) + c.Describe(descc) + + select { + case desc = <-descc: + default: + panic("no description provided by collector") + } + select { + case <-descc: + panic("more than one description provided by collector") + default: + } + + close(descc) + + // Make sure the Collector has a valid Desc by registering it with a + // temporary registry. + prometheus.NewRegistry().MustRegister(c) + + // Create a ConstMetric with the Desc. Since we don't know how many + // variable labels there are, try for as long as it needs. + for err := errors.New("dummy"); err != nil; lvs = append(lvs, magicString) { + m, err = prometheus.NewConstMetric(desc, prometheus.UntypedValue, 0, lvs...) + } + + // Write out the metric into a proto message and look at the labels. + // If the value is not the magicString, it is a constLabel, which doesn't interest us. + // If the label is curried, it doesn't interest us. + // In all other cases, only "code" or "method" is allowed. + if err := m.Write(&pm); err != nil { + panic("error checking metric for labels") + } + for _, label := range pm.Label { + name, value := label.GetName(), label.GetValue() + if value != magicString || isLabelCurried(c, name) { + continue + } + switch name { + case "code": + code = true + case "method": + method = true + default: + panic("metric partitioned with non-supported labels") + } + } + return +} + +func isLabelCurried(c prometheus.Collector, label string) bool { + // This is even hackier than the label test above. + // We essentially try to curry again and see if it works. + // But for that, we need to type-convert to the two + // types we use here, ObserverVec or *CounterVec. 
+ switch v := c.(type) { + case *prometheus.CounterVec: + if _, err := v.CurryWith(prometheus.Labels{label: "dummy"}); err == nil { + return false + } + case prometheus.ObserverVec: + if _, err := v.CurryWith(prometheus.Labels{label: "dummy"}); err == nil { + return false + } + default: + panic("unsupported metric vec type") + } + return true +} + +// emptyLabels is a one-time allocation for non-partitioned metrics to avoid +// unnecessary allocations on each request. +var emptyLabels = prometheus.Labels{} + +func labels(code, method bool, reqMethod string, status int, extraMethods ...string) prometheus.Labels { + if !(code || method) { + return emptyLabels + } + labels := prometheus.Labels{} + + if code { + labels["code"] = sanitizeCode(status) + } + if method { + labels["method"] = sanitizeMethod(reqMethod, extraMethods...) + } + + return labels +} + +func computeApproximateRequestSize(r *http.Request) int { + s := 0 + if r.URL != nil { + s += len(r.URL.String()) + } + + s += len(r.Method) + s += len(r.Proto) + for name, values := range r.Header { + s += len(name) + for _, value := range values { + s += len(value) + } + } + s += len(r.Host) + + // N.B. r.Form and r.MultipartForm are assumed to be included in r.URL. + + if r.ContentLength != -1 { + s += int(r.ContentLength) + } + return s +} + +// If the wrapped http.Handler has a known method, it will be sanitized and returned. +// Otherwise, "unknown" will be returned. The known method list can be extended +// as needed by using extraMethods parameter. +func sanitizeMethod(m string, extraMethods ...string) string { + // See https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods for + // the methods chosen as default. + switch m { + case "GET", "get": + return "get" + case "PUT", "put": + return "put" + case "HEAD", "head": + return "head" + case "POST", "post": + return "post" + case "DELETE", "delete": + return "delete" + case "CONNECT", "connect": + return "connect" + case "OPTIONS", "options": + return "options" + case "NOTIFY", "notify": + return "notify" + case "TRACE", "trace": + return "trace" + case "PATCH", "patch": + return "patch" + default: + for _, method := range extraMethods { + if strings.EqualFold(m, method) { + return strings.ToLower(m) + } + } + return "unknown" + } +} + +// If the wrapped http.Handler has not set a status code, i.e. the value is +// currently 0, sanitizeCode will return 200, for consistency with behavior in +// the stdlib. 
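Editor's note: since sanitizeMethod and sanitizeCode are unexported, an in-package test sketch is the natural way to pin down the behavior described above; this test is editorial, not part of the diff.

func TestSanitizeSketch(t *testing.T) {
	// Methods outside the default set collapse to "unknown" ...
	if got := sanitizeMethod("PURGE"); got != "unknown" {
		t.Errorf(`sanitizeMethod("PURGE") = %q, want "unknown"`, got)
	}
	// ... unless whitelisted via the extraMethods parameter.
	if got := sanitizeMethod("PURGE", "PURGE"); got != "purge" {
		t.Errorf(`sanitizeMethod("PURGE", "PURGE") = %q, want "purge"`, got)
	}
	// A zero status code is reported as "200", matching stdlib behavior.
	if got := sanitizeCode(0); got != "200" {
		t.Errorf(`sanitizeCode(0) = %q, want "200"`, got)
	}
}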
+func sanitizeCode(s int) string { + // See for accepted codes https://www.iana.org/assignments/http-status-codes/http-status-codes.xhtml + switch s { + case 100: + return "100" + case 101: + return "101" + + case 200, 0: + return "200" + case 201: + return "201" + case 202: + return "202" + case 203: + return "203" + case 204: + return "204" + case 205: + return "205" + case 206: + return "206" + + case 300: + return "300" + case 301: + return "301" + case 302: + return "302" + case 304: + return "304" + case 305: + return "305" + case 307: + return "307" + + case 400: + return "400" + case 401: + return "401" + case 402: + return "402" + case 403: + return "403" + case 404: + return "404" + case 405: + return "405" + case 406: + return "406" + case 407: + return "407" + case 408: + return "408" + case 409: + return "409" + case 410: + return "410" + case 411: + return "411" + case 412: + return "412" + case 413: + return "413" + case 414: + return "414" + case 415: + return "415" + case 416: + return "416" + case 417: + return "417" + case 418: + return "418" + + case 500: + return "500" + case 501: + return "501" + case 502: + return "502" + case 503: + return "503" + case 504: + return "504" + case 505: + return "505" + + case 428: + return "428" + case 429: + return "429" + case 431: + return "431" + case 511: + return "511" + + default: + if s >= 100 && s <= 599 { + return strconv.Itoa(s) + } + return "unknown" + } +} diff --git a/vendor/github.com/prometheus/client_golang/prometheus/registry.go b/vendor/github.com/prometheus/client_golang/prometheus/registry.go new file mode 100644 index 0000000000..44da9433be --- /dev/null +++ b/vendor/github.com/prometheus/client_golang/prometheus/registry.go @@ -0,0 +1,1075 @@ +// Copyright 2014 The Prometheus Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package prometheus + +import ( + "bytes" + "errors" + "fmt" + "os" + "path/filepath" + "runtime" + "sort" + "strconv" + "strings" + "sync" + "unicode/utf8" + + "github.com/prometheus/client_golang/prometheus/internal" + + "github.com/cespare/xxhash/v2" + dto "github.com/prometheus/client_model/go" + "github.com/prometheus/common/expfmt" + "google.golang.org/protobuf/proto" +) + +const ( + // Capacity for the channel to collect metrics and descriptors. + capMetricChan = 1000 + capDescChan = 10 +) + +// DefaultRegisterer and DefaultGatherer are the implementations of the +// Registerer and Gatherer interface a number of convenience functions in this +// package act on. Initially, both variables point to the same Registry, which +// has a process collector (currently on Linux only, see NewProcessCollector) +// and a Go collector (see NewGoCollector, in particular the note about +// stop-the-world implication with Go versions older than 1.9) already +// registered. This approach to keep default instances as global state mirrors +// the approach of other packages in the Go standard library. Note that there +// are caveats. 
Change the variables with caution and only if you understand the +// consequences. Users who want to avoid global state altogether should not use +// the convenience functions and act on custom instances instead. +var ( + defaultRegistry = NewRegistry() + DefaultRegisterer Registerer = defaultRegistry + DefaultGatherer Gatherer = defaultRegistry +) + +func init() { + MustRegister(NewProcessCollector(ProcessCollectorOpts{})) + MustRegister(NewGoCollector()) +} + +// NewRegistry creates a new vanilla Registry without any Collectors +// pre-registered. +func NewRegistry() *Registry { + return &Registry{ + collectorsByID: map[uint64]Collector{}, + descIDs: map[uint64]struct{}{}, + dimHashesByName: map[string]uint64{}, + } +} + +// NewPedanticRegistry returns a registry that checks during collection if each +// collected Metric is consistent with its reported Desc, and if the Desc has +// actually been registered with the registry. Unchecked Collectors (those whose +// Describe method does not yield any descriptors) are excluded from the check. +// +// Usually, a Registry will be happy as long as the union of all collected +// Metrics is consistent and valid even if some metrics are not consistent with +// their own Desc or a Desc provided by their registered Collector. Well-behaved +// Collectors and Metrics will only provide consistent Descs. This Registry is +// useful to test the implementation of Collectors and Metrics. +func NewPedanticRegistry() *Registry { + r := NewRegistry() + r.pedanticChecksEnabled = true + return r +} + +// Registerer is the interface for the part of a registry in charge of +// registering and unregistering. Users of custom registries should use +// Registerer as type for registration purposes (rather than the Registry type +// directly). In that way, they are free to use custom Registerer implementation +// (e.g. for testing purposes). +type Registerer interface { + // Register registers a new Collector to be included in metrics + // collection. It returns an error if the descriptors provided by the + // Collector are invalid or if they — in combination with descriptors of + // already registered Collectors — do not fulfill the consistency and + // uniqueness criteria described in the documentation of metric.Desc. + // + // If the provided Collector is equal to a Collector already registered + // (which includes the case of re-registering the same Collector), the + // returned error is an instance of AlreadyRegisteredError, which + // contains the previously registered Collector. + // + // A Collector whose Describe method does not yield any Desc is treated + // as unchecked. Registration will always succeed. No check for + // re-registering (see previous paragraph) is performed. Thus, the + // caller is responsible for not double-registering the same unchecked + // Collector, and for providing a Collector that will not cause + // inconsistent metrics on collection. (This would lead to scrape + // errors.) + Register(Collector) error + // MustRegister works like Register but registers any number of + // Collectors and panics upon the first registration that causes an + // error. + MustRegister(...Collector) + // Unregister unregisters the Collector that equals the Collector passed + // in as an argument. (Two Collectors are considered equal if their + // Describe method yields the same set of descriptors.) The function + // returns whether a Collector was unregistered. 
Note that an unchecked + // Collector cannot be unregistered (as its Describe method does not + // yield any descriptor). + // + // Note that even after unregistering, it will not be possible to + // register a new Collector that is inconsistent with the unregistered + // Collector, e.g. a Collector collecting metrics with the same name but + // a different help string. The rationale here is that the same registry + // instance must only collect consistent metrics throughout its + // lifetime. + Unregister(Collector) bool +} + +// Gatherer is the interface for the part of a registry in charge of gathering +// the collected metrics into a number of MetricFamilies. The Gatherer interface +// comes with the same general implication as described for the Registerer +// interface. +type Gatherer interface { + // Gather calls the Collect method of the registered Collectors and then + // gathers the collected metrics into a lexicographically sorted slice + // of uniquely named MetricFamily protobufs. Gather ensures that the + // returned slice is valid and self-consistent so that it can be used + // for valid exposition. As an exception to the strict consistency + // requirements described for metric.Desc, Gather will tolerate + // different sets of label names for metrics of the same metric family. + // + // Even if an error occurs, Gather attempts to gather as many metrics as + // possible. Hence, if a non-nil error is returned, the returned + // MetricFamily slice could be nil (in case of a fatal error that + // prevented any meaningful metric collection) or contain a number of + // MetricFamily protobufs, some of which might be incomplete, and some + // might be missing altogether. The returned error (which might be a + // MultiError) explains the details. Note that this is mostly useful for + // debugging purposes. If the gathered protobufs are to be used for + // exposition in actual monitoring, it is almost always better to not + // expose an incomplete result and instead disregard the returned + // MetricFamily protobufs in case the returned error is non-nil. + Gather() ([]*dto.MetricFamily, error) +} + +// Register registers the provided Collector with the DefaultRegisterer. +// +// Register is a shortcut for DefaultRegisterer.Register(c). See there for more +// details. +func Register(c Collector) error { + return DefaultRegisterer.Register(c) +} + +// MustRegister registers the provided Collectors with the DefaultRegisterer and +// panics if any error occurs. +// +// MustRegister is a shortcut for DefaultRegisterer.MustRegister(cs...). See +// there for more details. +func MustRegister(cs ...Collector) { + DefaultRegisterer.MustRegister(cs...) +} + +// Unregister removes the registration of the provided Collector from the +// DefaultRegisterer. +// +// Unregister is a shortcut for DefaultRegisterer.Unregister(c). See there for +// more details. +func Unregister(c Collector) bool { + return DefaultRegisterer.Unregister(c) +} + +// GathererFunc turns a function into a Gatherer. +type GathererFunc func() ([]*dto.MetricFamily, error) + +// Gather implements Gatherer. +func (gf GathererFunc) Gather() ([]*dto.MetricFamily, error) { + return gf() +} + +// AlreadyRegisteredError is returned by the Register method if the Collector to +// be registered has already been registered before, or a different Collector +// that collects the same metrics has been registered before. Registration fails +// in that case, but you can detect from the kind of error what has +// happened. 
+// The error contains fields for the existing Collector and the
+// (rejected) new Collector that equals the existing one. This can be used to
+// find out if an equal Collector has been registered before and switch over to
+// using the old one, as demonstrated in the example.
+type AlreadyRegisteredError struct {
+	ExistingCollector, NewCollector Collector
+}
+
+func (err AlreadyRegisteredError) Error() string {
+	return "duplicate metrics collector registration attempted"
+}
+
+// MultiError is a slice of errors implementing the error interface. It is used
+// by a Gatherer to report multiple errors during MetricFamily gathering.
+type MultiError []error
+
+// Error formats the contained errors as a bullet point list, preceded by the
+// total number of errors. Note that this results in a multi-line string.
+func (errs MultiError) Error() string {
+	if len(errs) == 0 {
+		return ""
+	}
+	buf := &bytes.Buffer{}
+	fmt.Fprintf(buf, "%d error(s) occurred:", len(errs))
+	for _, err := range errs {
+		fmt.Fprintf(buf, "\n* %s", err)
+	}
+	return buf.String()
+}
+
+// Append appends the provided error if it is not nil.
+func (errs *MultiError) Append(err error) {
+	if err != nil {
+		*errs = append(*errs, err)
+	}
+}
+
+// MaybeUnwrap returns nil if len(errs) is 0. It returns the first and only
+// contained error as error if len(errs) is 1. In all other cases, it returns
+// the MultiError directly. This is helpful for returning a MultiError in a way
+// that only uses the MultiError if needed.
+func (errs MultiError) MaybeUnwrap() error {
+	switch len(errs) {
+	case 0:
+		return nil
+	case 1:
+		return errs[0]
+	default:
+		return errs
+	}
+}
+
+// Registry registers Prometheus collectors, collects their metrics, and gathers
+// them into MetricFamilies for exposition. It implements Registerer, Gatherer,
+// and Collector. The zero value is not usable. Create instances with
+// NewRegistry or NewPedanticRegistry.
+//
+// Registry implements Collector to allow it to be used for creating groups of
+// metrics. See the Grouping example for how this can be done.
+type Registry struct {
+	mtx                   sync.RWMutex
+	collectorsByID        map[uint64]Collector // ID is a hash of the descIDs.
+	descIDs               map[uint64]struct{}
+	dimHashesByName       map[string]uint64
+	uncheckedCollectors   []Collector
+	pedanticChecksEnabled bool
+}
+
+// Register implements Registerer.
+func (r *Registry) Register(c Collector) error {
+	var (
+		descChan           = make(chan *Desc, capDescChan)
+		newDescIDs         = map[uint64]struct{}{}
+		newDimHashesByName = map[string]uint64{}
+		collectorID        uint64 // All desc IDs XOR'd together.
+		duplicateDescErr   error
+	)
+	go func() {
+		c.Describe(descChan)
+		close(descChan)
+	}()
+	r.mtx.Lock()
+	defer func() {
+		// Drain channel in case of premature return to not leak a goroutine.
+		for range descChan {
+		}
+		r.mtx.Unlock()
+	}()
+	// Conduct various tests...
+	for desc := range descChan {
+
+		// Is the descriptor valid at all?
+		if desc.err != nil {
+			return fmt.Errorf("descriptor %s is invalid: %w", desc, desc.err)
+		}
+
+		// Is the descID unique?
+		// (In other words: Is the fqName + constLabel combination unique?)
+		if _, exists := r.descIDs[desc.id]; exists {
+			duplicateDescErr = fmt.Errorf("descriptor %s already exists with the same fully-qualified name and const label values", desc)
+		}
+		// If it is not a duplicate desc in this collector, XOR it to
+		// the collectorID. (We allow duplicate descs within the same
+		// collector, but their existence must be a no-op.)
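+		// (Illustration: XOR is self-inverse, so folding the same ID in
+		// twice would cancel it out of collectorID; hence the !exists
+		// guard below.)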
+ if _, exists := newDescIDs[desc.id]; !exists { + newDescIDs[desc.id] = struct{}{} + collectorID ^= desc.id + } + + // Are all the label names and the help string consistent with + // previous descriptors of the same name? + // First check existing descriptors... + if dimHash, exists := r.dimHashesByName[desc.fqName]; exists { + if dimHash != desc.dimHash { + return fmt.Errorf("a previously registered descriptor with the same fully-qualified name as %s has different label names or a different help string", desc) + } + } else { + // ...then check the new descriptors already seen. + if dimHash, exists := newDimHashesByName[desc.fqName]; exists { + if dimHash != desc.dimHash { + return fmt.Errorf("descriptors reported by collector have inconsistent label names or help strings for the same fully-qualified name, offender is %s", desc) + } + } else { + newDimHashesByName[desc.fqName] = desc.dimHash + } + } + } + // A Collector yielding no Desc at all is considered unchecked. + if len(newDescIDs) == 0 { + r.uncheckedCollectors = append(r.uncheckedCollectors, c) + return nil + } + if existing, exists := r.collectorsByID[collectorID]; exists { + switch e := existing.(type) { + case *wrappingCollector: + return AlreadyRegisteredError{ + ExistingCollector: e.unwrapRecursively(), + NewCollector: c, + } + default: + return AlreadyRegisteredError{ + ExistingCollector: e, + NewCollector: c, + } + } + } + // If the collectorID is new, but at least one of the descs existed + // before, we are in trouble. + if duplicateDescErr != nil { + return duplicateDescErr + } + + // Only after all tests have passed, actually register. + r.collectorsByID[collectorID] = c + for hash := range newDescIDs { + r.descIDs[hash] = struct{}{} + } + for name, dimHash := range newDimHashesByName { + r.dimHashesByName[name] = dimHash + } + return nil +} + +// Unregister implements Registerer. +func (r *Registry) Unregister(c Collector) bool { + var ( + descChan = make(chan *Desc, capDescChan) + descIDs = map[uint64]struct{}{} + collectorID uint64 // All desc IDs XOR'd together. + ) + go func() { + c.Describe(descChan) + close(descChan) + }() + for desc := range descChan { + if _, exists := descIDs[desc.id]; !exists { + collectorID ^= desc.id + descIDs[desc.id] = struct{}{} + } + } + + r.mtx.RLock() + if _, exists := r.collectorsByID[collectorID]; !exists { + r.mtx.RUnlock() + return false + } + r.mtx.RUnlock() + + r.mtx.Lock() + defer r.mtx.Unlock() + + delete(r.collectorsByID, collectorID) + for id := range descIDs { + delete(r.descIDs, id) + } + // dimHashesByName is left untouched as those must be consistent + // throughout the lifetime of a program. + return true +} + +// MustRegister implements Registerer. +func (r *Registry) MustRegister(cs ...Collector) { + for _, c := range cs { + if err := r.Register(c); err != nil { + panic(err) + } + } +} + +// Gather implements Gatherer. +func (r *Registry) Gather() ([]*dto.MetricFamily, error) { + r.mtx.RLock() + + if len(r.collectorsByID) == 0 && len(r.uncheckedCollectors) == 0 { + // Fast path. + r.mtx.RUnlock() + return nil, nil + } + + var ( + checkedMetricChan = make(chan Metric, capMetricChan) + uncheckedMetricChan = make(chan Metric, capMetricChan) + metricHashes = map[uint64]struct{}{} + wg sync.WaitGroup + errs MultiError // The collected errors to return in the end. 
+		registeredDescIDs map[uint64]struct{} // Only used for pedantic checks
+	)
+
+	goroutineBudget := len(r.collectorsByID) + len(r.uncheckedCollectors)
+	metricFamiliesByName := make(map[string]*dto.MetricFamily, len(r.dimHashesByName))
+	checkedCollectors := make(chan Collector, len(r.collectorsByID))
+	uncheckedCollectors := make(chan Collector, len(r.uncheckedCollectors))
+	for _, collector := range r.collectorsByID {
+		checkedCollectors <- collector
+	}
+	for _, collector := range r.uncheckedCollectors {
+		uncheckedCollectors <- collector
+	}
+	// In case pedantic checks are enabled, we have to copy the map before
+	// giving up the RLock.
+	if r.pedanticChecksEnabled {
+		registeredDescIDs = make(map[uint64]struct{}, len(r.descIDs))
+		for id := range r.descIDs {
+			registeredDescIDs[id] = struct{}{}
+		}
+	}
+	r.mtx.RUnlock()
+
+	wg.Add(goroutineBudget)
+
+	collectWorker := func() {
+		for {
+			select {
+			case collector := <-checkedCollectors:
+				collector.Collect(checkedMetricChan)
+			case collector := <-uncheckedCollectors:
+				collector.Collect(uncheckedMetricChan)
+			default:
+				return
+			}
+			wg.Done()
+		}
+	}
+
+	// Start the first worker now to make sure at least one is running.
+	go collectWorker()
+	goroutineBudget--
+
+	// Close checkedMetricChan and uncheckedMetricChan once all collectors
+	// are collected.
+	go func() {
+		wg.Wait()
+		close(checkedMetricChan)
+		close(uncheckedMetricChan)
+	}()
+
+	// Drain checkedMetricChan and uncheckedMetricChan in case of premature return.
+	defer func() {
+		if checkedMetricChan != nil {
+			for range checkedMetricChan {
+			}
+		}
+		if uncheckedMetricChan != nil {
+			for range uncheckedMetricChan {
+			}
+		}
+	}()
+
+	// Copy the channel references so we can nil them out later to remove
+	// them from the select statements below.
+	cmc := checkedMetricChan
+	umc := uncheckedMetricChan
+
+	for {
+		select {
+		case metric, ok := <-cmc:
+			if !ok {
+				cmc = nil
+				break
+			}
+			errs.Append(processMetric(
+				metric, metricFamiliesByName,
+				metricHashes,
+				registeredDescIDs,
+			))
+		case metric, ok := <-umc:
+			if !ok {
+				umc = nil
+				break
+			}
+			errs.Append(processMetric(
+				metric, metricFamiliesByName,
+				metricHashes,
+				nil,
+			))
+		default:
+			if goroutineBudget <= 0 || len(checkedCollectors)+len(uncheckedCollectors) == 0 {
+				// All collectors are already being worked on, or
+				// we have already started as many goroutines as
+				// there are collectors. Do the same as above,
+				// just without the default.
+				select {
+				case metric, ok := <-cmc:
+					if !ok {
+						cmc = nil
+						break
+					}
+					errs.Append(processMetric(
+						metric, metricFamiliesByName,
+						metricHashes,
+						registeredDescIDs,
+					))
+				case metric, ok := <-umc:
+					if !ok {
+						umc = nil
+						break
+					}
+					errs.Append(processMetric(
+						metric, metricFamiliesByName,
+						metricHashes,
+						nil,
+					))
+				}
+				break
+			}
+			// Start more workers.
+			go collectWorker()
+			goroutineBudget--
+			runtime.Gosched()
+		}
+		// Once both checkedMetricChan and uncheckedMetricChan are closed
+		// and drained, the contraption above will nil out cmc and umc,
+		// and then we can leave the collect loop here.
+		if cmc == nil && umc == nil {
+			break
+		}
+	}
+	return internal.NormalizeMetricFamilies(metricFamiliesByName), errs.MaybeUnwrap()
+}
+
+// Describe implements Collector.
+func (r *Registry) Describe(ch chan<- *Desc) {
+	r.mtx.RLock()
+	defer r.mtx.RUnlock()
+
+	// Only report the checked Collectors; unchecked collectors don't report any
+	// Desc.
+	for _, c := range r.collectorsByID {
+		c.Describe(ch)
+	}
+}
+
+// Collect implements Collector.
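+//
+// Because Registry also implements Collector, registries can be nested for
+// grouping (see the Grouping example referenced on Registry). A hedged sketch
+// (requestCount is illustrative):
+//
+//	group := prometheus.NewRegistry()
+//	group.MustRegister(requestCount)
+//	prometheus.MustRegister(group) // Gathered as a single Collector.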
+func (r *Registry) Collect(ch chan<- Metric) { + r.mtx.RLock() + defer r.mtx.RUnlock() + + for _, c := range r.collectorsByID { + c.Collect(ch) + } + for _, c := range r.uncheckedCollectors { + c.Collect(ch) + } +} + +// WriteToTextfile calls Gather on the provided Gatherer, encodes the result in the +// Prometheus text format, and writes it to a temporary file. Upon success, the +// temporary file is renamed to the provided filename. +// +// This is intended for use with the textfile collector of the node exporter. +// Note that the node exporter expects the filename to be suffixed with ".prom". +func WriteToTextfile(filename string, g Gatherer) error { + tmp, err := os.CreateTemp(filepath.Dir(filename), filepath.Base(filename)) + if err != nil { + return err + } + defer os.Remove(tmp.Name()) + + mfs, err := g.Gather() + if err != nil { + return err + } + for _, mf := range mfs { + if _, err := expfmt.MetricFamilyToText(tmp, mf); err != nil { + return err + } + } + if err := tmp.Close(); err != nil { + return err + } + + if err := os.Chmod(tmp.Name(), 0o644); err != nil { + return err + } + return os.Rename(tmp.Name(), filename) +} + +// processMetric is an internal helper method only used by the Gather method. +func processMetric( + metric Metric, + metricFamiliesByName map[string]*dto.MetricFamily, + metricHashes map[uint64]struct{}, + registeredDescIDs map[uint64]struct{}, +) error { + desc := metric.Desc() + // Wrapped metrics collected by an unchecked Collector can have an + // invalid Desc. + if desc.err != nil { + return desc.err + } + dtoMetric := &dto.Metric{} + if err := metric.Write(dtoMetric); err != nil { + return fmt.Errorf("error collecting metric %v: %w", desc, err) + } + metricFamily, ok := metricFamiliesByName[desc.fqName] + if ok { // Existing name. + if metricFamily.GetHelp() != desc.help { + return fmt.Errorf( + "collected metric %s %s has help %q but should have %q", + desc.fqName, dtoMetric, desc.help, metricFamily.GetHelp(), + ) + } + // TODO(beorn7): Simplify switch once Desc has type. + switch metricFamily.GetType() { + case dto.MetricType_COUNTER: + if dtoMetric.Counter == nil { + return fmt.Errorf( + "collected metric %s %s should be a Counter", + desc.fqName, dtoMetric, + ) + } + case dto.MetricType_GAUGE: + if dtoMetric.Gauge == nil { + return fmt.Errorf( + "collected metric %s %s should be a Gauge", + desc.fqName, dtoMetric, + ) + } + case dto.MetricType_SUMMARY: + if dtoMetric.Summary == nil { + return fmt.Errorf( + "collected metric %s %s should be a Summary", + desc.fqName, dtoMetric, + ) + } + case dto.MetricType_UNTYPED: + if dtoMetric.Untyped == nil { + return fmt.Errorf( + "collected metric %s %s should be Untyped", + desc.fqName, dtoMetric, + ) + } + case dto.MetricType_HISTOGRAM: + if dtoMetric.Histogram == nil { + return fmt.Errorf( + "collected metric %s %s should be a Histogram", + desc.fqName, dtoMetric, + ) + } + default: + panic("encountered MetricFamily with invalid type") + } + } else { // New name. + metricFamily = &dto.MetricFamily{} + metricFamily.Name = proto.String(desc.fqName) + metricFamily.Help = proto.String(desc.help) + // TODO(beorn7): Simplify switch once Desc has type. 
+ switch { + case dtoMetric.Gauge != nil: + metricFamily.Type = dto.MetricType_GAUGE.Enum() + case dtoMetric.Counter != nil: + metricFamily.Type = dto.MetricType_COUNTER.Enum() + case dtoMetric.Summary != nil: + metricFamily.Type = dto.MetricType_SUMMARY.Enum() + case dtoMetric.Untyped != nil: + metricFamily.Type = dto.MetricType_UNTYPED.Enum() + case dtoMetric.Histogram != nil: + metricFamily.Type = dto.MetricType_HISTOGRAM.Enum() + default: + return fmt.Errorf("empty metric collected: %s", dtoMetric) + } + if err := checkSuffixCollisions(metricFamily, metricFamiliesByName); err != nil { + return err + } + metricFamiliesByName[desc.fqName] = metricFamily + } + if err := checkMetricConsistency(metricFamily, dtoMetric, metricHashes); err != nil { + return err + } + if registeredDescIDs != nil { + // Is the desc registered at all? + if _, exist := registeredDescIDs[desc.id]; !exist { + return fmt.Errorf( + "collected metric %s %s with unregistered descriptor %s", + metricFamily.GetName(), dtoMetric, desc, + ) + } + if err := checkDescConsistency(metricFamily, dtoMetric, desc); err != nil { + return err + } + } + metricFamily.Metric = append(metricFamily.Metric, dtoMetric) + return nil +} + +// Gatherers is a slice of Gatherer instances that implements the Gatherer +// interface itself. Its Gather method calls Gather on all Gatherers in the +// slice in order and returns the merged results. Errors returned from the +// Gather calls are all returned in a flattened MultiError. Duplicate and +// inconsistent Metrics are skipped (first occurrence in slice order wins) and +// reported in the returned error. +// +// Gatherers can be used to merge the Gather results from multiple +// Registries. It also provides a way to directly inject existing MetricFamily +// protobufs into the gathering by creating a custom Gatherer with a Gather +// method that simply returns the existing MetricFamily protobufs. Note that no +// registration is involved (in contrast to Collector registration), so +// obviously registration-time checks cannot happen. Any inconsistencies between +// the gathered MetricFamilies are reported as errors by the Gather method, and +// inconsistent Metrics are dropped. Invalid parts of the MetricFamilies +// (e.g. syntactically invalid metric or label names) will go undetected. +type Gatherers []Gatherer + +// Gather implements Gatherer. +func (gs Gatherers) Gather() ([]*dto.MetricFamily, error) { + var ( + metricFamiliesByName = map[string]*dto.MetricFamily{} + metricHashes = map[uint64]struct{}{} + errs MultiError // The collected errors to return in the end. 
+ ) + + for i, g := range gs { + mfs, err := g.Gather() + if err != nil { + multiErr := MultiError{} + if errors.As(err, &multiErr) { + for _, err := range multiErr { + errs = append(errs, fmt.Errorf("[from Gatherer #%d] %w", i+1, err)) + } + } else { + errs = append(errs, fmt.Errorf("[from Gatherer #%d] %w", i+1, err)) + } + } + for _, mf := range mfs { + existingMF, exists := metricFamiliesByName[mf.GetName()] + if exists { + if existingMF.GetHelp() != mf.GetHelp() { + errs = append(errs, fmt.Errorf( + "gathered metric family %s has help %q but should have %q", + mf.GetName(), mf.GetHelp(), existingMF.GetHelp(), + )) + continue + } + if existingMF.GetType() != mf.GetType() { + errs = append(errs, fmt.Errorf( + "gathered metric family %s has type %s but should have %s", + mf.GetName(), mf.GetType(), existingMF.GetType(), + )) + continue + } + } else { + existingMF = &dto.MetricFamily{} + existingMF.Name = mf.Name + existingMF.Help = mf.Help + existingMF.Type = mf.Type + if err := checkSuffixCollisions(existingMF, metricFamiliesByName); err != nil { + errs = append(errs, err) + continue + } + metricFamiliesByName[mf.GetName()] = existingMF + } + for _, m := range mf.Metric { + if err := checkMetricConsistency(existingMF, m, metricHashes); err != nil { + errs = append(errs, err) + continue + } + existingMF.Metric = append(existingMF.Metric, m) + } + } + } + return internal.NormalizeMetricFamilies(metricFamiliesByName), errs.MaybeUnwrap() +} + +// checkSuffixCollisions checks for collisions with the “magic” suffixes the +// Prometheus text format and the internal metric representation of the +// Prometheus server add while flattening Summaries and Histograms. +func checkSuffixCollisions(mf *dto.MetricFamily, mfs map[string]*dto.MetricFamily) error { + var ( + newName = mf.GetName() + newType = mf.GetType() + newNameWithoutSuffix = "" + ) + switch { + case strings.HasSuffix(newName, "_count"): + newNameWithoutSuffix = newName[:len(newName)-6] + case strings.HasSuffix(newName, "_sum"): + newNameWithoutSuffix = newName[:len(newName)-4] + case strings.HasSuffix(newName, "_bucket"): + newNameWithoutSuffix = newName[:len(newName)-7] + } + if newNameWithoutSuffix != "" { + if existingMF, ok := mfs[newNameWithoutSuffix]; ok { + switch existingMF.GetType() { + case dto.MetricType_SUMMARY: + if !strings.HasSuffix(newName, "_bucket") { + return fmt.Errorf( + "collected metric named %q collides with previously collected summary named %q", + newName, newNameWithoutSuffix, + ) + } + case dto.MetricType_HISTOGRAM: + return fmt.Errorf( + "collected metric named %q collides with previously collected histogram named %q", + newName, newNameWithoutSuffix, + ) + } + } + } + if newType == dto.MetricType_SUMMARY || newType == dto.MetricType_HISTOGRAM { + if _, ok := mfs[newName+"_count"]; ok { + return fmt.Errorf( + "collected histogram or summary named %q collides with previously collected metric named %q", + newName, newName+"_count", + ) + } + if _, ok := mfs[newName+"_sum"]; ok { + return fmt.Errorf( + "collected histogram or summary named %q collides with previously collected metric named %q", + newName, newName+"_sum", + ) + } + } + if newType == dto.MetricType_HISTOGRAM { + if _, ok := mfs[newName+"_bucket"]; ok { + return fmt.Errorf( + "collected histogram named %q collides with previously collected metric named %q", + newName, newName+"_bucket", + ) + } + } + return nil +} + +// checkMetricConsistency checks if the provided Metric is consistent with the +// provided MetricFamily. 
It also hashes the Metric labels and the MetricFamily +// name. If the resulting hash is already in the provided metricHashes, an error +// is returned. If not, it is added to metricHashes. +func checkMetricConsistency( + metricFamily *dto.MetricFamily, + dtoMetric *dto.Metric, + metricHashes map[uint64]struct{}, +) error { + name := metricFamily.GetName() + + // Type consistency with metric family. + if metricFamily.GetType() == dto.MetricType_GAUGE && dtoMetric.Gauge == nil || + metricFamily.GetType() == dto.MetricType_COUNTER && dtoMetric.Counter == nil || + metricFamily.GetType() == dto.MetricType_SUMMARY && dtoMetric.Summary == nil || + metricFamily.GetType() == dto.MetricType_HISTOGRAM && dtoMetric.Histogram == nil || + metricFamily.GetType() == dto.MetricType_UNTYPED && dtoMetric.Untyped == nil { + return fmt.Errorf( + "collected metric %q { %s} is not a %s", + name, dtoMetric, metricFamily.GetType(), + ) + } + + previousLabelName := "" + for _, labelPair := range dtoMetric.GetLabel() { + labelName := labelPair.GetName() + if labelName == previousLabelName { + return fmt.Errorf( + "collected metric %q { %s} has two or more labels with the same name: %s", + name, dtoMetric, labelName, + ) + } + if !checkLabelName(labelName) { + return fmt.Errorf( + "collected metric %q { %s} has a label with an invalid name: %s", + name, dtoMetric, labelName, + ) + } + if dtoMetric.Summary != nil && labelName == quantileLabel { + return fmt.Errorf( + "collected metric %q { %s} must not have an explicit %q label", + name, dtoMetric, quantileLabel, + ) + } + if !utf8.ValidString(labelPair.GetValue()) { + return fmt.Errorf( + "collected metric %q { %s} has a label named %q whose value is not utf8: %#v", + name, dtoMetric, labelName, labelPair.GetValue()) + } + previousLabelName = labelName + } + + // Is the metric unique (i.e. no other metric with the same name and the same labels)? + h := xxhash.New() + h.WriteString(name) + h.Write(separatorByteSlice) + // Make sure label pairs are sorted. We depend on it for the consistency + // check. + if !sort.IsSorted(internal.LabelPairSorter(dtoMetric.Label)) { + // We cannot sort dtoMetric.Label in place as it is immutable by contract. + copiedLabels := make([]*dto.LabelPair, len(dtoMetric.Label)) + copy(copiedLabels, dtoMetric.Label) + sort.Sort(internal.LabelPairSorter(copiedLabels)) + dtoMetric.Label = copiedLabels + } + for _, lp := range dtoMetric.Label { + h.WriteString(lp.GetName()) + h.Write(separatorByteSlice) + h.WriteString(lp.GetValue()) + h.Write(separatorByteSlice) + } + if dtoMetric.TimestampMs != nil { + h.WriteString(strconv.FormatInt(*(dtoMetric.TimestampMs), 10)) + h.Write(separatorByteSlice) + } + hSum := h.Sum64() + if _, exists := metricHashes[hSum]; exists { + return fmt.Errorf( + "collected metric %q { %s} was collected before with the same name and label values", + name, dtoMetric, + ) + } + metricHashes[hSum] = struct{}{} + return nil +} + +func checkDescConsistency( + metricFamily *dto.MetricFamily, + dtoMetric *dto.Metric, + desc *Desc, +) error { + // Desc help consistency with metric family help. + if metricFamily.GetHelp() != desc.help { + return fmt.Errorf( + "collected metric %s %s has help %q but should have %q", + metricFamily.GetName(), dtoMetric, metricFamily.GetHelp(), desc.help, + ) + } + + // Is the desc consistent with the content of the metric? 
+	lpsFromDesc := make([]*dto.LabelPair, len(desc.constLabelPairs), len(dtoMetric.Label))
+	copy(lpsFromDesc, desc.constLabelPairs)
+	for _, l := range desc.variableLabels {
+		lpsFromDesc = append(lpsFromDesc, &dto.LabelPair{
+			Name: proto.String(l.Name),
+		})
+	}
+	if len(lpsFromDesc) != len(dtoMetric.Label) {
+		return fmt.Errorf(
+			"labels in collected metric %s %s are inconsistent with descriptor %s",
+			metricFamily.GetName(), dtoMetric, desc,
+		)
+	}
+	sort.Sort(internal.LabelPairSorter(lpsFromDesc))
+	for i, lpFromDesc := range lpsFromDesc {
+		lpFromMetric := dtoMetric.Label[i]
+		if lpFromDesc.GetName() != lpFromMetric.GetName() ||
+			lpFromDesc.Value != nil && lpFromDesc.GetValue() != lpFromMetric.GetValue() {
+			return fmt.Errorf(
+				"labels in collected metric %s %s are inconsistent with descriptor %s",
+				metricFamily.GetName(), dtoMetric, desc,
+			)
+		}
+	}
+	return nil
+}
+
+var _ TransactionalGatherer = &MultiTRegistry{}
+
+// MultiTRegistry is a TransactionalGatherer that joins gathered metrics from multiple
+// transactional gatherers.
+//
+// It is the caller's responsibility to ensure that the joined registries have
+// mutually exclusive metric families; no deduplication will happen.
+type MultiTRegistry struct {
+	tGatherers []TransactionalGatherer
+}
+
+// NewMultiTRegistry creates a MultiTRegistry.
+func NewMultiTRegistry(tGatherers ...TransactionalGatherer) *MultiTRegistry {
+	return &MultiTRegistry{
+		tGatherers: tGatherers,
+	}
+}
+
+// Gather implements TransactionalGatherer interface.
+func (r *MultiTRegistry) Gather() (mfs []*dto.MetricFamily, done func(), err error) {
+	errs := MultiError{}
+
+	dFns := make([]func(), 0, len(r.tGatherers))
+	// TODO(bwplotka): Implement concurrency for those?
+	for _, g := range r.tGatherers {
+		// TODO(bwplotka): Check for duplicates?
+		m, d, err := g.Gather()
+		errs.Append(err)
+
+		mfs = append(mfs, m...)
+		dFns = append(dFns, d)
+	}
+
+	// TODO(bwplotka): Consider sort in place, given metric family in gather is sorted already.
+	sort.Slice(mfs, func(i, j int) bool {
+		return *mfs[i].Name < *mfs[j].Name
+	})
+	return mfs, func() {
+		for _, d := range dFns {
+			d()
+		}
+	}, errs.MaybeUnwrap()
+}
+
+// TransactionalGatherer is a Gatherer-like interface whose caller signals, via
+// the returned done function, when the memory holding the returned metric
+// families is no longer used. This allows implementations that cache.
+type TransactionalGatherer interface {
+	// Gather returns metrics in a lexicographically sorted slice
+	// of uniquely named MetricFamily protobufs. Gather ensures that the
+	// returned slice is valid and self-consistent so that it can be used
+	// for valid exposition. As an exception to the strict consistency
+	// requirements described for metric.Desc, Gather will tolerate
+	// different sets of label names for metrics of the same metric family.
+	//
+	// Even if an error occurs, Gather attempts to gather as many metrics as
+	// possible. Hence, if a non-nil error is returned, the returned
+	// MetricFamily slice could be nil (in case of a fatal error that
+	// prevented any meaningful metric collection) or contain a number of
+	// MetricFamily protobufs, some of which might be incomplete, and some
+	// might be missing altogether. The returned error (which might be a
+	// MultiError) explains the details. Note that this is mostly useful for
+	// debugging purposes.
+	// If the gathered protobufs are to be used for
+	// exposition in actual monitoring, it is almost always better to not
+	// expose an incomplete result and instead disregard the returned
+	// MetricFamily protobufs in case the returned error is non-nil.
+	//
+	// Important: done is expected to be called (even if an error occurs!)
+	// once the caller no longer needs the returned slice of dto.MetricFamily.
+	Gather() (_ []*dto.MetricFamily, done func(), err error)
+}
+
+// ToTransactionalGatherer wraps a Gatherer in a TransactionalGatherer whose
+// done function is a no-op.
+func ToTransactionalGatherer(g Gatherer) TransactionalGatherer {
+	return &noTransactionGatherer{g: g}
+}
+
+type noTransactionGatherer struct {
+	g Gatherer
+}
+
+// Gather implements TransactionalGatherer interface.
+func (g *noTransactionGatherer) Gather() (_ []*dto.MetricFamily, done func(), err error) {
+	mfs, err := g.g.Gather()
+	return mfs, func() {}, err
+}
diff --git a/vendor/github.com/prometheus/client_golang/prometheus/summary.go b/vendor/github.com/prometheus/client_golang/prometheus/summary.go
new file mode 100644
index 0000000000..dd359264e5
--- /dev/null
+++ b/vendor/github.com/prometheus/client_golang/prometheus/summary.go
@@ -0,0 +1,766 @@
+// Copyright 2014 The Prometheus Authors
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package prometheus
+
+import (
+	"fmt"
+	"math"
+	"runtime"
+	"sort"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	dto "github.com/prometheus/client_model/go"
+
+	"github.com/beorn7/perks/quantile"
+	"google.golang.org/protobuf/proto"
+)
+
+// quantileLabel is used for the label that defines the quantile in a
+// summary.
+const quantileLabel = "quantile"
+
+// A Summary captures individual observations from an event or sample stream and
+// summarizes them in a manner similar to traditional summary statistics: 1. sum
+// of observations, 2. observation count, 3. rank estimations.
+//
+// A typical use-case is the observation of request latencies. By default, a
+// Summary provides the median, the 90th and the 99th percentile of the latency
+// as rank estimations. However, the default behavior will change in the
+// upcoming v1.0.0 of the library. There will be no rank estimations at all by
+// default. For a sane transition, it is recommended to set the desired rank
+// estimations explicitly.
+//
+// Note that the rank estimations cannot be aggregated in a meaningful way with
+// the Prometheus query language (i.e. you cannot average or add them). If you
+// need aggregatable quantiles (e.g. you want the 99th percentile latency of all
+// queries served across all instances of a service), consider the Histogram
+// metric type. See the Prometheus documentation for more details.
+//
+// To create Summary instances, use NewSummary.
+type Summary interface {
+	Metric
+	Collector
+
+	// Observe adds a single observation to the summary. Observations are
+	// usually positive or zero.
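	//
	// A hedged sketch of the common latency pattern (handleRequest and
	// requestDuration are illustrative):
	//
	//	start := time.Now()
	//	handleRequest()
	//	requestDuration.Observe(time.Since(start).Seconds())
	//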
Negative observations are accepted but + // prevent current versions of Prometheus from properly detecting + // counter resets in the sum of observations. See + // https://prometheus.io/docs/practices/histograms/#count-and-sum-of-observations + // for details. + Observe(float64) +} + +var errQuantileLabelNotAllowed = fmt.Errorf( + "%q is not allowed as label name in summaries", quantileLabel, +) + +// Default values for SummaryOpts. +const ( + // DefMaxAge is the default duration for which observations stay + // relevant. + DefMaxAge time.Duration = 10 * time.Minute + // DefAgeBuckets is the default number of buckets used to calculate the + // age of observations. + DefAgeBuckets = 5 + // DefBufCap is the standard buffer size for collecting Summary observations. + DefBufCap = 500 +) + +// SummaryOpts bundles the options for creating a Summary metric. It is +// mandatory to set Name to a non-empty string. While all other fields are +// optional and can safely be left at their zero value, it is recommended to set +// a help string and to explicitly set the Objectives field to the desired value +// as the default value will change in the upcoming v1.0.0 of the library. +type SummaryOpts struct { + // Namespace, Subsystem, and Name are components of the fully-qualified + // name of the Summary (created by joining these components with + // "_"). Only Name is mandatory, the others merely help structuring the + // name. Note that the fully-qualified name of the Summary must be a + // valid Prometheus metric name. + Namespace string + Subsystem string + Name string + + // Help provides information about this Summary. + // + // Metrics with the same fully-qualified name must have the same Help + // string. + Help string + + // ConstLabels are used to attach fixed labels to this metric. Metrics + // with the same fully-qualified name must have the same label names in + // their ConstLabels. + // + // Due to the way a Summary is represented in the Prometheus text format + // and how it is handled by the Prometheus server internally, “quantile” + // is an illegal label name. Construction of a Summary or SummaryVec + // will panic if this label name is used in ConstLabels. + // + // ConstLabels are only used rarely. In particular, do not use them to + // attach the same labels to all your metrics. Those use cases are + // better covered by target labels set by the scraping Prometheus + // server, or by one specific metric (e.g. a build_info or a + // machine_role metric). See also + // https://prometheus.io/docs/instrumenting/writing_exporters/#target-labels-not-static-scraped-labels + ConstLabels Labels + + // Objectives defines the quantile rank estimates with their respective + // absolute error. If Objectives[q] = e, then the value reported for q + // will be the φ-quantile value for some φ between q-e and q+e. The + // default value is an empty map, resulting in a summary without + // quantiles. + Objectives map[float64]float64 + + // MaxAge defines the duration for which an observation stays relevant + // for the summary. Only applies to pre-calculated quantiles, does not + // apply to _sum and _count. Must be positive. The default value is + // DefMaxAge. + MaxAge time.Duration + + // AgeBuckets is the number of buckets used to exclude observations that + // are older than MaxAge from the summary. A higher number has a + // resource penalty, so only increase it if the higher resolution is + // really required. 
For very high observation rates, you might want to
+	// reduce the number of age buckets. With only one age bucket, you will
+	// effectively see a complete reset of the summary each time MaxAge has
+	// passed. The default value is DefAgeBuckets.
+	AgeBuckets uint32
+
+	// BufCap defines the default sample stream buffer size. The default
+	// value of DefBufCap should suffice for most uses. If there is a need
+	// to increase the value, a multiple of 500 is recommended (because that
+	// is the internal buffer size of the underlying package
+	// "github.com/beorn7/perks/quantile").
+	BufCap uint32
+}
+
+// SummaryVecOpts bundles the options to create a SummaryVec metric.
+// It is mandatory to set SummaryOpts, see there for mandatory fields. VariableLabels
+// is optional and can safely be left to its default value.
+type SummaryVecOpts struct {
+	SummaryOpts
+
+	// VariableLabels are used to partition the metric vector by the given set
+	// of labels. Each label value will be constrained with the optional Constraint
+	// function, if provided.
+	VariableLabels ConstrainableLabels
+}
+
+// Problem with the sliding-window decay algorithm... The Merge method of
+// perks/quantile is actually not working as advertised - and it might be
+// unfixable, as the underlying algorithm is apparently not capable of merging
+// summaries in the first place. To avoid using Merge, we are currently adding
+// observations to _each_ age bucket, i.e. the effort to add a sample is
+// essentially multiplied by the number of age buckets. When rotating age
+// buckets, we empty the previous head stream. On scrape time, we simply take
+// the quantiles from the head stream (no merging required). Result: More effort
+// on observation time, less effort on scrape time, which is exactly the
+// opposite of what we try to accomplish, but at least the results are correct.
+//
+// The quite elegant previous contraption to merge the age buckets efficiently
+// on scrape time (see code up to commit 6b9530d72ea715f0ba612c0120e6e09fbf1d49d0)
+// can't be used anymore.
+
+// NewSummary creates a new Summary based on the provided SummaryOpts.
+func NewSummary(opts SummaryOpts) Summary {
+	return newSummary(
+		NewDesc(
+			BuildFQName(opts.Namespace, opts.Subsystem, opts.Name),
+			opts.Help,
+			nil,
+			opts.ConstLabels,
+		),
+		opts,
+	)
+}
+
+func newSummary(desc *Desc, opts SummaryOpts, labelValues ...string) Summary {
+	if len(desc.variableLabels) != len(labelValues) {
+		panic(makeInconsistentCardinalityError(desc.fqName, desc.variableLabels.labelNames(), labelValues))
+	}
+
+	for _, n := range desc.variableLabels {
+		if n.Name == quantileLabel {
+			panic(errQuantileLabelNotAllowed)
+		}
+	}
+	for _, lp := range desc.constLabelPairs {
+		if lp.GetName() == quantileLabel {
+			panic(errQuantileLabelNotAllowed)
+		}
+	}
+
+	if opts.Objectives == nil {
+		opts.Objectives = map[float64]float64{}
+	}
+
+	if opts.MaxAge < 0 {
+		panic(fmt.Errorf("illegal max age MaxAge=%v", opts.MaxAge))
+	}
+	if opts.MaxAge == 0 {
+		opts.MaxAge = DefMaxAge
+	}
+
+	if opts.AgeBuckets == 0 {
+		opts.AgeBuckets = DefAgeBuckets
+	}
+
+	if opts.BufCap == 0 {
+		opts.BufCap = DefBufCap
+	}
+
+	if len(opts.Objectives) == 0 {
+		// Use the lock-free implementation of a Summary without objectives.
+		s := &noObjectivesSummary{
+			desc:       desc,
+			labelPairs: MakeLabelPairs(desc, labelValues),
+			counts:     [2]*summaryCounts{{}, {}},
+		}
+		s.init(s) // Init self-collection.
+		return s
+	}
+
+	s := &summary{
+		desc: desc,
+
+		objectives:       opts.Objectives,
+		sortedObjectives: make([]float64, 0, len(opts.Objectives)),
+
+		labelPairs: MakeLabelPairs(desc, labelValues),
+
+		hotBuf:         make([]float64, 0, opts.BufCap),
+		coldBuf:        make([]float64, 0, opts.BufCap),
+		streamDuration: opts.MaxAge / time.Duration(opts.AgeBuckets),
+	}
+	s.headStreamExpTime = time.Now().Add(s.streamDuration)
+	s.hotBufExpTime = s.headStreamExpTime
+
+	for i := uint32(0); i < opts.AgeBuckets; i++ {
+		s.streams = append(s.streams, s.newStream())
+	}
+	s.headStream = s.streams[0]
+
+	for qu := range s.objectives {
+		s.sortedObjectives = append(s.sortedObjectives, qu)
+	}
+	sort.Float64s(s.sortedObjectives)
+
+	s.init(s) // Init self-collection.
+	return s
+}
+
+type summary struct {
+	selfCollector
+
+	bufMtx sync.Mutex // Protects hotBuf and hotBufExpTime.
+	mtx    sync.Mutex // Protects every other moving part.
+	// Lock bufMtx before mtx if both are needed.
+
+	desc *Desc
+
+	objectives       map[float64]float64
+	sortedObjectives []float64
+
+	labelPairs []*dto.LabelPair
+
+	sum float64
+	cnt uint64
+
+	hotBuf, coldBuf []float64
+
+	streams                          []*quantile.Stream
+	streamDuration                   time.Duration
+	headStream                       *quantile.Stream
+	headStreamIdx                    int
+	headStreamExpTime, hotBufExpTime time.Time
+}
+
+func (s *summary) Desc() *Desc {
+	return s.desc
+}
+
+func (s *summary) Observe(v float64) {
+	s.bufMtx.Lock()
+	defer s.bufMtx.Unlock()
+
+	now := time.Now()
+	if now.After(s.hotBufExpTime) {
+		s.asyncFlush(now)
+	}
+	s.hotBuf = append(s.hotBuf, v)
+	if len(s.hotBuf) == cap(s.hotBuf) {
+		s.asyncFlush(now)
+	}
+}
+
+func (s *summary) Write(out *dto.Metric) error {
+	sum := &dto.Summary{}
+	qs := make([]*dto.Quantile, 0, len(s.objectives))
+
+	s.bufMtx.Lock()
+	s.mtx.Lock()
+	// Swap bufs even if hotBuf is empty to set new hotBufExpTime.
+	s.swapBufs(time.Now())
+	s.bufMtx.Unlock()
+
+	s.flushColdBuf()
+	sum.SampleCount = proto.Uint64(s.cnt)
+	sum.SampleSum = proto.Float64(s.sum)
+
+	for _, rank := range s.sortedObjectives {
+		var q float64
+		if s.headStream.Count() == 0 {
+			q = math.NaN()
+		} else {
+			q = s.headStream.Query(rank)
+		}
+		qs = append(qs, &dto.Quantile{
+			Quantile: proto.Float64(rank),
+			Value:    proto.Float64(q),
+		})
+	}
+
+	s.mtx.Unlock()
+
+	if len(qs) > 0 {
+		sort.Sort(quantSort(qs))
+	}
+	sum.Quantile = qs
+
+	out.Summary = sum
+	out.Label = s.labelPairs
+	return nil
+}
+
+func (s *summary) newStream() *quantile.Stream {
+	return quantile.NewTargeted(s.objectives)
+}
+
+// asyncFlush needs bufMtx locked.
+func (s *summary) asyncFlush(now time.Time) {
+	s.mtx.Lock()
+	s.swapBufs(now)
+
+	// Unblock the original goroutine that was responsible for the mutation
+	// that triggered the compaction. But hold onto the global non-buffer
+	// state mutex until the operation finishes.
+	go func() {
+		s.flushColdBuf()
+		s.mtx.Unlock()
+	}()
+}
+
+// maybeRotateStreams needs mtx AND bufMtx locked.
+func (s *summary) maybeRotateStreams() {
+	for !s.hotBufExpTime.Equal(s.headStreamExpTime) {
+		s.headStream.Reset()
+		s.headStreamIdx++
+		if s.headStreamIdx >= len(s.streams) {
+			s.headStreamIdx = 0
+		}
+		s.headStream = s.streams[s.headStreamIdx]
+		s.headStreamExpTime = s.headStreamExpTime.Add(s.streamDuration)
+	}
+}
+
+// flushColdBuf needs mtx locked.
+func (s *summary) flushColdBuf() {
+	for _, v := range s.coldBuf {
+		for _, stream := range s.streams {
+			stream.Insert(v)
+		}
+		s.cnt++
+		s.sum += v
+	}
+	s.coldBuf = s.coldBuf[0:0]
+	s.maybeRotateStreams()
+}
+
+// swapBufs needs mtx AND bufMtx locked, coldBuf must be empty.
+func (s *summary) swapBufs(now time.Time) {
+	if len(s.coldBuf) != 0 {
+		panic("coldBuf is not empty")
+	}
+	s.hotBuf, s.coldBuf = s.coldBuf, s.hotBuf
+	// hotBuf is now empty and gets new expiration set.
+	for now.After(s.hotBufExpTime) {
+		s.hotBufExpTime = s.hotBufExpTime.Add(s.streamDuration)
+	}
+}
+
+type summaryCounts struct {
+	// sumBits contains the bits of the float64 representing the sum of all
+	// observations. sumBits and count have to go first in the struct to
+	// guarantee alignment for atomic operations.
+	// http://golang.org/pkg/sync/atomic/#pkg-note-BUG
+	sumBits uint64
+	count   uint64
+}
+
+type noObjectivesSummary struct {
+	// countAndHotIdx enables lock-free writes with use of atomic updates.
+	// The most significant bit is the hot index [0 or 1] of the count field
+	// below. Observe calls update the hot one. All remaining bits count the
+	// number of Observe calls. Observe starts by incrementing this counter,
+	// and finishes by incrementing the count field in the respective
+	// summaryCounts, as a marker for completion.
+	//
+	// Calls of the Write method (which are non-mutating reads from the
+	// perspective of the summary) swap the hot and cold counts under the
+	// writeMtx lock. A cooldown is awaited (while locked) by comparing the
+	// number of observations with the initiation count. Once they match,
+	// the last observation on the now cool one has completed. All cool
+	// fields must be merged into the new hot before releasing writeMtx.
+
+	// Fields with atomic access first! See alignment constraint:
+	// http://golang.org/pkg/sync/atomic/#pkg-note-BUG
+	countAndHotIdx uint64
+
+	selfCollector
+	desc     *Desc
+	writeMtx sync.Mutex // Only used in the Write method.
+
+	// Two counts, one is "hot" for lock-free observations, the other is
+	// "cold" for writing out a dto.Metric. It has to be an array of
+	// pointers to guarantee 64bit alignment of the summaryCounts, see
+	// http://golang.org/pkg/sync/atomic/#pkg-note-BUG.
+	counts [2]*summaryCounts
+
+	labelPairs []*dto.LabelPair
+}
+
+func (s *noObjectivesSummary) Desc() *Desc {
+	return s.desc
+}
+
+func (s *noObjectivesSummary) Observe(v float64) {
+	// We increment s.countAndHotIdx so that the counter in the lower
+	// 63 bits gets incremented. At the same time, we get the new value
+	// back, which we can use to find the currently-hot counts.
+	n := atomic.AddUint64(&s.countAndHotIdx, 1)
+	hotCounts := s.counts[n>>63]
+
+	for {
+		oldBits := atomic.LoadUint64(&hotCounts.sumBits)
+		newBits := math.Float64bits(math.Float64frombits(oldBits) + v)
+		if atomic.CompareAndSwapUint64(&hotCounts.sumBits, oldBits, newBits) {
+			break
+		}
+	}
+	// Increment count last as we take it as a signal that the observation
+	// is complete.
+	atomic.AddUint64(&hotCounts.count, 1)
+}
+
+func (s *noObjectivesSummary) Write(out *dto.Metric) error {
+	// For simplicity, we protect this whole method by a mutex. It is not in
+	// the hot path, i.e. Observe is called much more often than Write. The
+	// complication of making Write lock-free isn't worth it, if possible at
+	// all.
+	s.writeMtx.Lock()
+	defer s.writeMtx.Unlock()
+
+	// Adding 1<<63 switches the hot index (from 0 to 1 or from 1 to 0)
+	// without touching the count bits.
See the struct comments for a full + // description of the algorithm. + n := atomic.AddUint64(&s.countAndHotIdx, 1<<63) + // count is contained unchanged in the lower 63 bits. + count := n & ((1 << 63) - 1) + // The most significant bit tells us which counts is hot. The complement + // is thus the cold one. + hotCounts := s.counts[n>>63] + coldCounts := s.counts[(^n)>>63] + + // Await cooldown. + for count != atomic.LoadUint64(&coldCounts.count) { + runtime.Gosched() // Let observations get work done. + } + + sum := &dto.Summary{ + SampleCount: proto.Uint64(count), + SampleSum: proto.Float64(math.Float64frombits(atomic.LoadUint64(&coldCounts.sumBits))), + } + + out.Summary = sum + out.Label = s.labelPairs + + // Finally add all the cold counts to the new hot counts and reset the cold counts. + atomic.AddUint64(&hotCounts.count, count) + atomic.StoreUint64(&coldCounts.count, 0) + for { + oldBits := atomic.LoadUint64(&hotCounts.sumBits) + newBits := math.Float64bits(math.Float64frombits(oldBits) + sum.GetSampleSum()) + if atomic.CompareAndSwapUint64(&hotCounts.sumBits, oldBits, newBits) { + atomic.StoreUint64(&coldCounts.sumBits, 0) + break + } + } + return nil +} + +type quantSort []*dto.Quantile + +func (s quantSort) Len() int { + return len(s) +} + +func (s quantSort) Swap(i, j int) { + s[i], s[j] = s[j], s[i] +} + +func (s quantSort) Less(i, j int) bool { + return s[i].GetQuantile() < s[j].GetQuantile() +} + +// SummaryVec is a Collector that bundles a set of Summaries that all share the +// same Desc, but have different values for their variable labels. This is used +// if you want to count the same thing partitioned by various dimensions +// (e.g. HTTP request latencies, partitioned by status code and method). Create +// instances with NewSummaryVec. +type SummaryVec struct { + *MetricVec +} + +// NewSummaryVec creates a new SummaryVec based on the provided SummaryOpts and +// partitioned by the given label names. +// +// Due to the way a Summary is represented in the Prometheus text format and how +// it is handled by the Prometheus server internally, “quantile” is an illegal +// label name. NewSummaryVec will panic if this label name is used. +func NewSummaryVec(opts SummaryOpts, labelNames []string) *SummaryVec { + return V2.NewSummaryVec(SummaryVecOpts{ + SummaryOpts: opts, + VariableLabels: UnconstrainedLabels(labelNames), + }) +} + +// NewSummaryVec creates a new SummaryVec based on the provided SummaryVecOpts. +func (v2) NewSummaryVec(opts SummaryVecOpts) *SummaryVec { + for _, ln := range opts.VariableLabels.labelNames() { + if ln == quantileLabel { + panic(errQuantileLabelNotAllowed) + } + } + desc := V2.NewDesc( + BuildFQName(opts.Namespace, opts.Subsystem, opts.Name), + opts.Help, + opts.VariableLabels, + opts.ConstLabels, + ) + return &SummaryVec{ + MetricVec: NewMetricVec(desc, func(lvs ...string) Metric { + return newSummary(desc, opts.SummaryOpts, lvs...) + }), + } +} + +// GetMetricWithLabelValues returns the Summary for the given slice of label +// values (same order as the variable labels in Desc). If that combination of +// label values is accessed for the first time, a new Summary is created. +// +// It is possible to call this method without using the returned Summary to only +// create the new Summary but leave it at its starting value, a Summary without +// any observations. 
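+//
+// A hedged sketch (the vector and label values are illustrative):
+//
+//	obs, err := latencies.GetMetricWithLabelValues("GET", "/api")
+//	if err != nil {
+//		return err // e.g. wrong number of label values
+//	}
+//	obs.Observe(0.42)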
+//
+// Keeping the Summary for later use is possible (and should be considered if
+// performance is critical), but keep in mind that Reset, DeleteLabelValues and
+// Delete can be used to delete the Summary from the SummaryVec. In that case,
+// the Summary will still exist, but it will not be exported anymore, even if a
+// Summary with the same label values is created later. See also the CounterVec
+// example.
+//
+// An error is returned if the number of label values is not the same as the
+// number of variable labels in Desc (minus any curried labels).
+//
+// Note that for more than one label value, this method is prone to mistakes
+// caused by an incorrect order of arguments. Consider GetMetricWith(Labels) as
+// an alternative to avoid that type of mistake. For higher label numbers, the
+// latter has a much more readable (albeit more verbose) syntax, but it comes
+// with a performance overhead (for creating and processing the Labels map).
+// See also the GaugeVec example.
+func (v *SummaryVec) GetMetricWithLabelValues(lvs ...string) (Observer, error) {
+	metric, err := v.MetricVec.GetMetricWithLabelValues(lvs...)
+	if metric != nil {
+		return metric.(Observer), err
+	}
+	return nil, err
+}
+
+// GetMetricWith returns the Summary for the given Labels map (the label names
+// must match those of the variable labels in Desc). If that label map is
+// accessed for the first time, a new Summary is created. Implications of
+// creating a Summary without using it and keeping the Summary for later use are
+// the same as for GetMetricWithLabelValues.
+//
+// An error is returned if the number and names of the Labels are inconsistent
+// with those of the variable labels in Desc (minus any curried labels).
+//
+// This method is used for the same purpose as
+// GetMetricWithLabelValues(...string). See there for pros and cons of the two
+// methods.
+func (v *SummaryVec) GetMetricWith(labels Labels) (Observer, error) {
+	metric, err := v.MetricVec.GetMetricWith(labels)
+	if metric != nil {
+		return metric.(Observer), err
+	}
+	return nil, err
+}
+
+// WithLabelValues works as GetMetricWithLabelValues, but panics where
+// GetMetricWithLabelValues would have returned an error. Not returning an
+// error allows shortcuts like
+//
+//	myVec.WithLabelValues("404", "GET").Observe(42.21)
+func (v *SummaryVec) WithLabelValues(lvs ...string) Observer {
+	s, err := v.GetMetricWithLabelValues(lvs...)
+	if err != nil {
+		panic(err)
+	}
+	return s
+}
+
+// With works as GetMetricWith, but panics where GetMetricWith would have
+// returned an error. Not returning an error allows shortcuts like
+//
+//	myVec.With(prometheus.Labels{"code": "404", "method": "GET"}).Observe(42.21)
+func (v *SummaryVec) With(labels Labels) Observer {
+	s, err := v.GetMetricWith(labels)
+	if err != nil {
+		panic(err)
+	}
+	return s
+}
+
+// CurryWith returns a vector curried with the provided labels, i.e. the
+// returned vector has those labels pre-set for all labeled operations performed
+// on it. The cardinality of the curried vector is reduced accordingly. The
+// order of the remaining labels stays the same (just with the curried labels
+// taken out of the sequence – which is relevant for the
+// (GetMetric)WithLabelValues methods). It is possible to curry a curried
+// vector, but only with labels not yet used for currying before.
+//
+// The metrics contained in the SummaryVec are shared between the curried and
+// uncurried vectors. They are just accessed differently.
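+//
+// A hedged currying sketch (the vector and labels are illustrative):
+//
+//	getOnly := latencies.MustCurryWith(prometheus.Labels{"method": "GET"})
+//	getOnly.WithLabelValues("/api").Observe(0.42)
+//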
+// Curried and uncurried vectors behave identically in terms of collection.
+// Only one must be registered with a given registry (usually the uncurried
+// version). The Reset method deletes all metrics, even if called on a curried
+// vector.
+func (v *SummaryVec) CurryWith(labels Labels) (ObserverVec, error) {
+	vec, err := v.MetricVec.CurryWith(labels)
+	if vec != nil {
+		return &SummaryVec{vec}, err
+	}
+	return nil, err
+}
+
+// MustCurryWith works as CurryWith but panics where CurryWith would have
+// returned an error.
+func (v *SummaryVec) MustCurryWith(labels Labels) ObserverVec {
+	vec, err := v.CurryWith(labels)
+	if err != nil {
+		panic(err)
+	}
+	return vec
+}
+
+type constSummary struct {
+	desc       *Desc
+	count      uint64
+	sum        float64
+	quantiles  map[float64]float64
+	labelPairs []*dto.LabelPair
+}
+
+func (s *constSummary) Desc() *Desc {
+	return s.desc
+}
+
+func (s *constSummary) Write(out *dto.Metric) error {
+	sum := &dto.Summary{}
+	qs := make([]*dto.Quantile, 0, len(s.quantiles))
+
+	sum.SampleCount = proto.Uint64(s.count)
+	sum.SampleSum = proto.Float64(s.sum)
+
+	for rank, q := range s.quantiles {
+		qs = append(qs, &dto.Quantile{
+			Quantile: proto.Float64(rank),
+			Value:    proto.Float64(q),
+		})
+	}
+
+	if len(qs) > 0 {
+		sort.Sort(quantSort(qs))
+	}
+	sum.Quantile = qs
+
+	out.Summary = sum
+	out.Label = s.labelPairs
+
+	return nil
+}
+
+// NewConstSummary returns a metric representing a Prometheus summary with fixed
+// values for the count, sum, and quantiles. As those parameters cannot be
+// changed, the returned value does not implement the Summary interface (but
+// only the Metric interface). Users of this package will not have much use for
+// it in regular operations. However, when implementing custom Collectors, it is
+// useful as a throw-away metric that is generated on the fly to send it to
+// Prometheus in the Collect method.
+//
+// quantiles maps ranks to quantile values. For example, a median latency of
+// 0.23s and a 99th percentile latency of 0.56s would be expressed as:
+//
+//	map[float64]float64{0.5: 0.23, 0.99: 0.56}
+//
+// NewConstSummary returns an error if the length of labelValues is not
+// consistent with the variable labels in Desc or if Desc is invalid.
+func NewConstSummary(
+	desc *Desc,
+	count uint64,
+	sum float64,
+	quantiles map[float64]float64,
+	labelValues ...string,
+) (Metric, error) {
+	if desc.err != nil {
+		return nil, desc.err
+	}
+	if err := validateLabelValues(labelValues, len(desc.variableLabels)); err != nil {
+		return nil, err
+	}
+	return &constSummary{
+		desc:       desc,
+		count:      count,
+		sum:        sum,
+		quantiles:  quantiles,
+		labelPairs: MakeLabelPairs(desc, labelValues),
+	}, nil
+}
+
+// MustNewConstSummary is a version of NewConstSummary that panics where
+// NewConstSummary would have returned an error.
+func MustNewConstSummary(
+	desc *Desc,
+	count uint64,
+	sum float64,
+	quantiles map[float64]float64,
+	labelValues ...string,
+) Metric {
+	m, err := NewConstSummary(desc, count, sum, quantiles, labelValues...)
+ if err != nil { + panic(err) + } + return m +} diff --git a/vendor/github.com/prometheus/client_golang/prometheus/value.go b/vendor/github.com/prometheus/client_golang/prometheus/value.go new file mode 100644 index 0000000000..5f6bb80014 --- /dev/null +++ b/vendor/github.com/prometheus/client_golang/prometheus/value.go @@ -0,0 +1,235 @@ +// Copyright 2014 The Prometheus Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package prometheus + +import ( + "fmt" + "sort" + "time" + "unicode/utf8" + + "github.com/prometheus/client_golang/prometheus/internal" + + dto "github.com/prometheus/client_model/go" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/timestamppb" +) + +// ValueType is an enumeration of metric types that represent a simple value. +type ValueType int + +// Possible values for the ValueType enum. Use UntypedValue to mark a metric +// with an unknown type. +const ( + _ ValueType = iota + CounterValue + GaugeValue + UntypedValue +) + +var ( + CounterMetricTypePtr = func() *dto.MetricType { d := dto.MetricType_COUNTER; return &d }() + GaugeMetricTypePtr = func() *dto.MetricType { d := dto.MetricType_GAUGE; return &d }() + UntypedMetricTypePtr = func() *dto.MetricType { d := dto.MetricType_UNTYPED; return &d }() +) + +func (v ValueType) ToDTO() *dto.MetricType { + switch v { + case CounterValue: + return CounterMetricTypePtr + case GaugeValue: + return GaugeMetricTypePtr + default: + return UntypedMetricTypePtr + } +} + +// valueFunc is a generic metric for simple values retrieved on collect time +// from a function. It implements Metric and Collector. Its effective type is +// determined by ValueType. This is a low-level building block used by the +// library to back the implementations of CounterFunc, GaugeFunc, and +// UntypedFunc. +type valueFunc struct { + selfCollector + + desc *Desc + valType ValueType + function func() float64 + labelPairs []*dto.LabelPair +} + +// newValueFunc returns a newly allocated valueFunc with the given Desc and +// ValueType. The value reported is determined by calling the given function +// from within the Write method. Take into account that metric collection may +// happen concurrently. If that results in concurrent calls to Write, like in +// the case where a valueFunc is directly registered with Prometheus, the +// provided function must be concurrency-safe. +func newValueFunc(desc *Desc, valueType ValueType, function func() float64) *valueFunc { + result := &valueFunc{ + desc: desc, + valType: valueType, + function: function, + labelPairs: MakeLabelPairs(desc, nil), + } + result.init(result) + return result +} + +func (v *valueFunc) Desc() *Desc { + return v.desc +} + +func (v *valueFunc) Write(out *dto.Metric) error { + return populateMetric(v.valType, v.function(), v.labelPairs, nil, out) +} + +// NewConstMetric returns a metric with one fixed value that cannot be +// changed. Users of this package will not have much use for it in regular +// operations. 
However, when implementing custom Collectors, it is useful as a +// throw-away metric that is generated on the fly to send it to Prometheus in +// the Collect method. NewConstMetric returns an error if the length of +// labelValues is not consistent with the variable labels in Desc or if Desc is +// invalid. +func NewConstMetric(desc *Desc, valueType ValueType, value float64, labelValues ...string) (Metric, error) { + if desc.err != nil { + return nil, desc.err + } + if err := validateLabelValues(labelValues, len(desc.variableLabels)); err != nil { + return nil, err + } + + metric := &dto.Metric{} + if err := populateMetric(valueType, value, MakeLabelPairs(desc, labelValues), nil, metric); err != nil { + return nil, err + } + + return &constMetric{ + desc: desc, + metric: metric, + }, nil +} + +// MustNewConstMetric is a version of NewConstMetric that panics where +// NewConstMetric would have returned an error. +func MustNewConstMetric(desc *Desc, valueType ValueType, value float64, labelValues ...string) Metric { + m, err := NewConstMetric(desc, valueType, value, labelValues...) + if err != nil { + panic(err) + } + return m +} + +type constMetric struct { + desc *Desc + metric *dto.Metric +} + +func (m *constMetric) Desc() *Desc { + return m.desc +} + +func (m *constMetric) Write(out *dto.Metric) error { + out.Label = m.metric.Label + out.Counter = m.metric.Counter + out.Gauge = m.metric.Gauge + out.Untyped = m.metric.Untyped + return nil +} + +func populateMetric( + t ValueType, + v float64, + labelPairs []*dto.LabelPair, + e *dto.Exemplar, + m *dto.Metric, +) error { + m.Label = labelPairs + switch t { + case CounterValue: + m.Counter = &dto.Counter{Value: proto.Float64(v), Exemplar: e} + case GaugeValue: + m.Gauge = &dto.Gauge{Value: proto.Float64(v)} + case UntypedValue: + m.Untyped = &dto.Untyped{Value: proto.Float64(v)} + default: + return fmt.Errorf("encountered unknown type %v", t) + } + return nil +} + +// MakeLabelPairs is a helper function to create protobuf LabelPairs from the +// variable and constant labels in the provided Desc. The values for the +// variable labels are defined by the labelValues slice, which must be in the +// same order as the corresponding variable labels in the Desc. +// +// This function is only needed for custom Metric implementations. See MetricVec +// example. +func MakeLabelPairs(desc *Desc, labelValues []string) []*dto.LabelPair { + totalLen := len(desc.variableLabels) + len(desc.constLabelPairs) + if totalLen == 0 { + // Super fast path. + return nil + } + if len(desc.variableLabels) == 0 { + // Moderately fast path. + return desc.constLabelPairs + } + labelPairs := make([]*dto.LabelPair, 0, totalLen) + for i, l := range desc.variableLabels { + labelPairs = append(labelPairs, &dto.LabelPair{ + Name: proto.String(l.Name), + Value: proto.String(labelValues[i]), + }) + } + labelPairs = append(labelPairs, desc.constLabelPairs...) + sort.Sort(internal.LabelPairSorter(labelPairs)) + return labelPairs +} + +// ExemplarMaxRunes is the max total number of runes allowed in exemplar labels. +const ExemplarMaxRunes = 128 + +// newExemplar creates a new dto.Exemplar from the provided values. An error is +// returned if any of the label names or values are invalid or if the total +// number of runes in the label names and values exceeds ExemplarMaxRunes. 
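+// For example, a single label named "trace_id" (8 runes) can carry a value of
+// at most 120 runes: 8+120 = 128 runes is still accepted, while 129 runes in
+// total would be rejected.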
+func newExemplar(value float64, ts time.Time, l Labels) (*dto.Exemplar, error) { + e := &dto.Exemplar{} + e.Value = proto.Float64(value) + tsProto := timestamppb.New(ts) + if err := tsProto.CheckValid(); err != nil { + return nil, err + } + e.Timestamp = tsProto + labelPairs := make([]*dto.LabelPair, 0, len(l)) + var runes int + for name, value := range l { + if !checkLabelName(name) { + return nil, fmt.Errorf("exemplar label name %q is invalid", name) + } + runes += utf8.RuneCountInString(name) + if !utf8.ValidString(value) { + return nil, fmt.Errorf("exemplar label value %q is not valid UTF-8", value) + } + runes += utf8.RuneCountInString(value) + labelPairs = append(labelPairs, &dto.LabelPair{ + Name: proto.String(name), + Value: proto.String(value), + }) + } + if runes > ExemplarMaxRunes { + return nil, fmt.Errorf("exemplar labels have %d runes, exceeding the limit of %d", runes, ExemplarMaxRunes) + } + e.Label = labelPairs + return e, nil +} diff --git a/vendor/github.com/prometheus/client_golang/prometheus/vec.go b/vendor/github.com/prometheus/client_golang/prometheus/vec.go new file mode 100644 index 0000000000..f0d0015a0f --- /dev/null +++ b/vendor/github.com/prometheus/client_golang/prometheus/vec.go @@ -0,0 +1,703 @@ +// Copyright 2014 The Prometheus Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package prometheus + +import ( + "fmt" + "sync" + + "github.com/prometheus/common/model" +) + +var labelsPool = &sync.Pool{ + New: func() interface{} { + return make(Labels) + }, +} + +func getLabelsFromPool() Labels { + return labelsPool.Get().(Labels) +} + +func putLabelsToPool(labels Labels) { + for k := range labels { + delete(labels, k) + } + + labelsPool.Put(labels) +} + +// MetricVec is a Collector to bundle metrics of the same name that differ in +// their label values. MetricVec is not used directly but as a building block +// for implementations of vectors of a given metric type, like GaugeVec, +// CounterVec, SummaryVec, and HistogramVec. It is exported so that it can be +// used for custom Metric implementations. +// +// To create a FooVec for custom Metric Foo, embed a pointer to MetricVec in +// FooVec and initialize it with NewMetricVec. Implement wrappers for +// GetMetricWithLabelValues and GetMetricWith that return (Foo, error) rather +// than (Metric, error). Similarly, create a wrapper for CurryWith that returns +// (*FooVec, error) rather than (*MetricVec, error). It is recommended to also +// add the convenience methods WithLabelValues, With, and MustCurryWith, which +// panic instead of returning errors. See also the MetricVec example. +type MetricVec struct { + *metricMap + + curry []curriedLabelValue + + // hashAdd and hashAddByte can be replaced for testing collision handling. + hashAdd func(h uint64, s string) uint64 + hashAddByte func(h uint64, b byte) uint64 +} + +// NewMetricVec returns an initialized metricVec. 
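+// A hedged sketch of the wrapper pattern described on MetricVec above; Foo and
+// newFoo are assumed stand-ins for a custom Metric type and its constructor,
+// not part of this package:
+//
+//	type FooVec struct{ *prometheus.MetricVec }
+//
+//	func NewFooVec(desc *prometheus.Desc) *FooVec {
+//		return &FooVec{prometheus.NewMetricVec(desc, func(lvs ...string) prometheus.Metric {
+//			return newFoo(desc, lvs...) // assumed constructor
+//		})}
+//	}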
+func NewMetricVec(desc *Desc, newMetric func(lvs ...string) Metric) *MetricVec { + return &MetricVec{ + metricMap: &metricMap{ + metrics: map[uint64][]metricWithLabelValues{}, + desc: desc, + newMetric: newMetric, + }, + hashAdd: hashAdd, + hashAddByte: hashAddByte, + } +} + +// DeleteLabelValues removes the metric where the variable labels are the same +// as those passed in as labels (same order as the VariableLabels in Desc). It +// returns true if a metric was deleted. +// +// It is not an error if the number of label values is not the same as the +// number of VariableLabels in Desc. However, such inconsistent label count can +// never match an actual metric, so the method will always return false in that +// case. +// +// Note that for more than one label value, this method is prone to mistakes +// caused by an incorrect order of arguments. Consider Delete(Labels) as an +// alternative to avoid that type of mistake. For higher label numbers, the +// latter has a much more readable (albeit more verbose) syntax, but it comes +// with a performance overhead (for creating and processing the Labels map). +// See also the CounterVec example. +func (m *MetricVec) DeleteLabelValues(lvs ...string) bool { + lvs = constrainLabelValues(m.desc, lvs, m.curry) + h, err := m.hashLabelValues(lvs) + if err != nil { + return false + } + + return m.metricMap.deleteByHashWithLabelValues(h, lvs, m.curry) +} + +// Delete deletes the metric where the variable labels are the same as those +// passed in as labels. It returns true if a metric was deleted. +// +// It is not an error if the number and names of the Labels are inconsistent +// with those of the VariableLabels in Desc. However, such inconsistent Labels +// can never match an actual metric, so the method will always return false in +// that case. +// +// This method is used for the same purpose as DeleteLabelValues(...string). See +// there for pros and cons of the two methods. +func (m *MetricVec) Delete(labels Labels) bool { + labels = constrainLabels(m.desc, labels) + defer putLabelsToPool(labels) + + h, err := m.hashLabels(labels) + if err != nil { + return false + } + + return m.metricMap.deleteByHashWithLabels(h, labels, m.curry) +} + +// DeletePartialMatch deletes all metrics where the variable labels contain all of those +// passed in as labels. The order of the labels does not matter. +// It returns the number of metrics deleted. +// +// Note that curried labels will never be matched if deleting from the curried vector. +// To match curried labels with DeletePartialMatch, it must be called on the base vector. +func (m *MetricVec) DeletePartialMatch(labels Labels) int { + labels = constrainLabels(m.desc, labels) + defer putLabelsToPool(labels) + + return m.metricMap.deleteByLabels(labels, m.curry) +} + +// Without explicit forwarding of Describe, Collect, Reset, those methods won't +// show up in GoDoc. + +// Describe implements Collector. +func (m *MetricVec) Describe(ch chan<- *Desc) { m.metricMap.Describe(ch) } + +// Collect implements Collector. +func (m *MetricVec) Collect(ch chan<- Metric) { m.metricMap.Collect(ch) } + +// Reset deletes all metrics in this vector. +func (m *MetricVec) Reset() { m.metricMap.Reset() } + +// CurryWith returns a vector curried with the provided labels, i.e. the +// returned vector has those labels pre-set for all labeled operations performed +// on it. The cardinality of the curried vector is reduced accordingly. 
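+// For instance, currying a vector that has the variable labels "method" and
+// "code" with Labels{"method": "GET"} leaves a vector whose only remaining
+// variable label is "code".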
The +// order of the remaining labels stays the same (just with the curried labels +// taken out of the sequence – which is relevant for the +// (GetMetric)WithLabelValues methods). It is possible to curry a curried +// vector, but only with labels not yet used for currying before. +// +// The metrics contained in the MetricVec are shared between the curried and +// uncurried vectors. They are just accessed differently. Curried and uncurried +// vectors behave identically in terms of collection. Only one must be +// registered with a given registry (usually the uncurried version). The Reset +// method deletes all metrics, even if called on a curried vector. +// +// Note that CurryWith is usually not called directly but through a wrapper +// around MetricVec, implementing a vector for a specific Metric +// implementation, for example GaugeVec. +func (m *MetricVec) CurryWith(labels Labels) (*MetricVec, error) { + var ( + newCurry []curriedLabelValue + oldCurry = m.curry + iCurry int + ) + for i, label := range m.desc.variableLabels { + val, ok := labels[label.Name] + if iCurry < len(oldCurry) && oldCurry[iCurry].index == i { + if ok { + return nil, fmt.Errorf("label name %q is already curried", label.Name) + } + newCurry = append(newCurry, oldCurry[iCurry]) + iCurry++ + } else { + if !ok { + continue // Label stays uncurried. + } + newCurry = append(newCurry, curriedLabelValue{i, label.Constrain(val)}) + } + } + if l := len(oldCurry) + len(labels) - len(newCurry); l > 0 { + return nil, fmt.Errorf("%d unknown label(s) found during currying", l) + } + + return &MetricVec{ + metricMap: m.metricMap, + curry: newCurry, + hashAdd: m.hashAdd, + hashAddByte: m.hashAddByte, + }, nil +} + +// GetMetricWithLabelValues returns the Metric for the given slice of label +// values (same order as the variable labels in Desc). If that combination of +// label values is accessed for the first time, a new Metric is created (by +// calling the newMetric function provided during construction of the +// MetricVec). +// +// It is possible to call this method without using the returned Metric to only +// create the new Metric but leave it in its initial state. +// +// Keeping the Metric for later use is possible (and should be considered if +// performance is critical), but keep in mind that Reset, DeleteLabelValues and +// Delete can be used to delete the Metric from the MetricVec. In that case, the +// Metric will still exist, but it will not be exported anymore, even if a +// Metric with the same label values is created later. +// +// An error is returned if the number of label values is not the same as the +// number of variable labels in Desc (minus any curried labels). +// +// Note that for more than one label value, this method is prone to mistakes +// caused by an incorrect order of arguments. Consider GetMetricWith(Labels) as +// an alternative to avoid that type of mistake. For higher label numbers, the +// latter has a much more readable (albeit more verbose) syntax, but it comes +// with a performance overhead (for creating and processing the Labels map). +// +// Note that GetMetricWithLabelValues is usually not called directly but through +// a wrapper around MetricVec, implementing a vector for a specific Metric +// implementation, for example GaugeVec. 
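+// A hedged usage sketch through such a wrapper; httpReqs is assumed to be a
+// *CounterVec with the variable labels "method" and "code":
+//
+//	c, err := httpReqs.GetMetricWithLabelValues("GET", "200")
+//	if err == nil {
+//		c.Inc()
+//	}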
+func (m *MetricVec) GetMetricWithLabelValues(lvs ...string) (Metric, error) { + lvs = constrainLabelValues(m.desc, lvs, m.curry) + h, err := m.hashLabelValues(lvs) + if err != nil { + return nil, err + } + + return m.metricMap.getOrCreateMetricWithLabelValues(h, lvs, m.curry), nil +} + +// GetMetricWith returns the Metric for the given Labels map (the label names +// must match those of the variable labels in Desc). If that label map is +// accessed for the first time, a new Metric is created. Implications of +// creating a Metric without using it and keeping the Metric for later use +// are the same as for GetMetricWithLabelValues. +// +// An error is returned if the number and names of the Labels are inconsistent +// with those of the variable labels in Desc (minus any curried labels). +// +// This method is used for the same purpose as +// GetMetricWithLabelValues(...string). See there for pros and cons of the two +// methods. +// +// Note that GetMetricWith is usually not called directly but through a wrapper +// around MetricVec, implementing a vector for a specific Metric implementation, +// for example GaugeVec. +func (m *MetricVec) GetMetricWith(labels Labels) (Metric, error) { + labels = constrainLabels(m.desc, labels) + defer putLabelsToPool(labels) + + h, err := m.hashLabels(labels) + if err != nil { + return nil, err + } + + return m.metricMap.getOrCreateMetricWithLabels(h, labels, m.curry), nil +} + +func (m *MetricVec) hashLabelValues(vals []string) (uint64, error) { + if err := validateLabelValues(vals, len(m.desc.variableLabels)-len(m.curry)); err != nil { + return 0, err + } + + var ( + h = hashNew() + curry = m.curry + iVals, iCurry int + ) + for i := 0; i < len(m.desc.variableLabels); i++ { + if iCurry < len(curry) && curry[iCurry].index == i { + h = m.hashAdd(h, curry[iCurry].value) + iCurry++ + } else { + h = m.hashAdd(h, vals[iVals]) + iVals++ + } + h = m.hashAddByte(h, model.SeparatorByte) + } + return h, nil +} + +func (m *MetricVec) hashLabels(labels Labels) (uint64, error) { + if err := validateValuesInLabels(labels, len(m.desc.variableLabels)-len(m.curry)); err != nil { + return 0, err + } + + var ( + h = hashNew() + curry = m.curry + iCurry int + ) + for i, label := range m.desc.variableLabels { + val, ok := labels[label.Name] + if iCurry < len(curry) && curry[iCurry].index == i { + if ok { + return 0, fmt.Errorf("label name %q is already curried", label.Name) + } + h = m.hashAdd(h, curry[iCurry].value) + iCurry++ + } else { + if !ok { + return 0, fmt.Errorf("label name %q missing in label map", label.Name) + } + h = m.hashAdd(h, val) + } + h = m.hashAddByte(h, model.SeparatorByte) + } + return h, nil +} + +// metricWithLabelValues provides the metric and its label values for +// disambiguation on hash collision. +type metricWithLabelValues struct { + values []string + metric Metric +} + +// curriedLabelValue sets the curried value for a label at the given index. +type curriedLabelValue struct { + index int + value string +} + +// metricMap is a helper for metricVec and shared between differently curried +// metricVecs. +type metricMap struct { + mtx sync.RWMutex // Protects metrics. + metrics map[uint64][]metricWithLabelValues + desc *Desc + newMetric func(labelValues ...string) Metric +} + +// Describe implements Collector. It will send exactly one Desc to the provided +// channel. +func (m *metricMap) Describe(ch chan<- *Desc) { + ch <- m.desc +} + +// Collect implements Collector. 
+func (m *metricMap) Collect(ch chan<- Metric) {
+	m.mtx.RLock()
+	defer m.mtx.RUnlock()
+
+	for _, metrics := range m.metrics {
+		for _, metric := range metrics {
+			ch <- metric.metric
+		}
+	}
+}
+
+// Reset deletes all metrics in this vector.
+func (m *metricMap) Reset() {
+	m.mtx.Lock()
+	defer m.mtx.Unlock()
+
+	for h := range m.metrics {
+		delete(m.metrics, h)
+	}
+}
+
+// deleteByHashWithLabelValues removes the metric from the hash bucket h. If
+// there are multiple matches in the bucket, use lvs to select a metric and
+// remove only that metric.
+func (m *metricMap) deleteByHashWithLabelValues(
+	h uint64, lvs []string, curry []curriedLabelValue,
+) bool {
+	m.mtx.Lock()
+	defer m.mtx.Unlock()
+
+	metrics, ok := m.metrics[h]
+	if !ok {
+		return false
+	}
+
+	i := findMetricWithLabelValues(metrics, lvs, curry)
+	if i >= len(metrics) {
+		return false
+	}
+
+	if len(metrics) > 1 {
+		old := metrics
+		m.metrics[h] = append(metrics[:i], metrics[i+1:]...)
+		old[len(old)-1] = metricWithLabelValues{}
+	} else {
+		delete(m.metrics, h)
+	}
+	return true
+}
+
+// deleteByHashWithLabels removes the metric from the hash bucket h. If there
+// are multiple matches in the bucket, use labels to select a metric and remove
+// only that metric.
+func (m *metricMap) deleteByHashWithLabels(
+	h uint64, labels Labels, curry []curriedLabelValue,
+) bool {
+	m.mtx.Lock()
+	defer m.mtx.Unlock()
+
+	metrics, ok := m.metrics[h]
+	if !ok {
+		return false
+	}
+	i := findMetricWithLabels(m.desc, metrics, labels, curry)
+	if i >= len(metrics) {
+		return false
+	}
+
+	if len(metrics) > 1 {
+		old := metrics
+		m.metrics[h] = append(metrics[:i], metrics[i+1:]...)
+		old[len(old)-1] = metricWithLabelValues{}
+	} else {
+		delete(m.metrics, h)
+	}
+	return true
+}
+
+// deleteByLabels deletes all metrics whose variable labels contain the given
+// labels and returns the number of metrics deleted.
+func (m *metricMap) deleteByLabels(labels Labels, curry []curriedLabelValue) int {
+	m.mtx.Lock()
+	defer m.mtx.Unlock()
+
+	var numDeleted int
+
+	for h, metrics := range m.metrics {
+		i := findMetricWithPartialLabels(m.desc, metrics, labels, curry)
+		if i >= len(metrics) {
+			// Didn't find matching labels in this metric slice.
+			continue
+		}
+		delete(m.metrics, h)
+		numDeleted++
+	}
+
+	return numDeleted
+}
+
+// findMetricWithPartialLabels returns the index of the matching metric or
+// len(metrics) if not found.
+func findMetricWithPartialLabels(
+	desc *Desc, metrics []metricWithLabelValues, labels Labels, curry []curriedLabelValue,
+) int {
+	for i, metric := range metrics {
+		if matchPartialLabels(desc, metric.values, labels, curry) {
+			return i
+		}
+	}
+	return len(metrics)
+}
+
+// indexOf searches the given slice of strings for the target string and
+// returns the index or len(items), as well as a boolean indicating whether the
+// search succeeded.
+func indexOf(target string, items []string) (int, bool) {
+	for i, l := range items {
+		if l == target {
+			return i, true
+		}
+	}
+	return len(items), false
+}
+
+// valueMatchesVariableOrCurriedValue determines if a value was previously
+// curried, and returns whether it matches either the "base" value or the
+// curried value accordingly. It also indicates whether the match is against a
+// curried or uncurried value.
+func valueMatchesVariableOrCurriedValue(targetValue string, index int, values []string, curry []curriedLabelValue) (bool, bool) {
+	for _, curriedValue := range curry {
+		if curriedValue.index == index {
+			// This label was curried. See if the curried value matches our target.
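+			// Example: with variable labels "method" and "code" where "method" is
+			// curried to "GET", a target value of "GET" at that index yields
+			// (true, true); any other target yields (false, true).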
+			return curriedValue.value == targetValue, true
+		}
+	}
+	// This label was not curried. See if the current value matches our target label.
+	return values[index] == targetValue, false
+}
+
+// matchPartialLabels searches the current metric and returns whether all of the
+// target label:value pairs are present.
+func matchPartialLabels(desc *Desc, values []string, labels Labels, curry []curriedLabelValue) bool {
+	for l, v := range labels {
+		// Check if the target label exists in our metrics and get the index.
+		varLabelIndex, validLabel := indexOf(l, desc.variableLabels.labelNames())
+		if validLabel {
+			// Check the value of that label against the target value.
+			// We don't consider curried values in partial matches.
+			matches, curried := valueMatchesVariableOrCurriedValue(v, varLabelIndex, values, curry)
+			if matches && !curried {
+				continue
+			}
+		}
+		return false
+	}
+	return true
+}
+
+// getOrCreateMetricWithLabelValues retrieves the metric by hash and label
+// values or creates it and returns the new one.
+//
+// This function acquires the mutex itself (first the read lock, then the write
+// lock if the metric has to be created).
+func (m *metricMap) getOrCreateMetricWithLabelValues(
+	hash uint64, lvs []string, curry []curriedLabelValue,
+) Metric {
+	m.mtx.RLock()
+	metric, ok := m.getMetricWithHashAndLabelValues(hash, lvs, curry)
+	m.mtx.RUnlock()
+	if ok {
+		return metric
+	}
+
+	m.mtx.Lock()
+	defer m.mtx.Unlock()
+	metric, ok = m.getMetricWithHashAndLabelValues(hash, lvs, curry)
+	if !ok {
+		inlinedLVs := inlineLabelValues(lvs, curry)
+		metric = m.newMetric(inlinedLVs...)
+		m.metrics[hash] = append(m.metrics[hash], metricWithLabelValues{values: inlinedLVs, metric: metric})
+	}
+	return metric
+}
+
+// getOrCreateMetricWithLabels retrieves the metric by hash and labels or
+// creates it and returns the new one.
+//
+// This function acquires the mutex itself (first the read lock, then the write
+// lock if the metric has to be created).
+func (m *metricMap) getOrCreateMetricWithLabels(
+	hash uint64, labels Labels, curry []curriedLabelValue,
+) Metric {
+	m.mtx.RLock()
+	metric, ok := m.getMetricWithHashAndLabels(hash, labels, curry)
+	m.mtx.RUnlock()
+	if ok {
+		return metric
+	}
+
+	m.mtx.Lock()
+	defer m.mtx.Unlock()
+	metric, ok = m.getMetricWithHashAndLabels(hash, labels, curry)
+	if !ok {
+		lvs := extractLabelValues(m.desc, labels, curry)
+		metric = m.newMetric(lvs...)
+		m.metrics[hash] = append(m.metrics[hash], metricWithLabelValues{values: lvs, metric: metric})
+	}
+	return metric
+}
+
+// getMetricWithHashAndLabelValues gets a metric while handling possible
+// collisions in the hash space. Must be called while holding the read mutex.
+func (m *metricMap) getMetricWithHashAndLabelValues(
+	h uint64, lvs []string, curry []curriedLabelValue,
+) (Metric, bool) {
+	metrics, ok := m.metrics[h]
+	if ok {
+		if i := findMetricWithLabelValues(metrics, lvs, curry); i < len(metrics) {
+			return metrics[i].metric, true
+		}
+	}
+	return nil, false
+}
+
+// getMetricWithHashAndLabels gets a metric while handling possible collisions
+// in the hash space. Must be called while holding the read mutex.
+func (m *metricMap) getMetricWithHashAndLabels(
+	h uint64, labels Labels, curry []curriedLabelValue,
+) (Metric, bool) {
+	metrics, ok := m.metrics[h]
+	if ok {
+		if i := findMetricWithLabels(m.desc, metrics, labels, curry); i < len(metrics) {
+			return metrics[i].metric, true
+		}
+	}
+	return nil, false
+}
+
+// findMetricWithLabelValues returns the index of the matching metric or
+// len(metrics) if not found.
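+// The linear scan below only matters on (rare) hash collisions, where two
+// different label-value sets end up in the same bucket and must be told apart
+// by comparing the stored values directly.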
+func findMetricWithLabelValues( + metrics []metricWithLabelValues, lvs []string, curry []curriedLabelValue, +) int { + for i, metric := range metrics { + if matchLabelValues(metric.values, lvs, curry) { + return i + } + } + return len(metrics) +} + +// findMetricWithLabels returns the index of the matching metric or len(metrics) +// if not found. +func findMetricWithLabels( + desc *Desc, metrics []metricWithLabelValues, labels Labels, curry []curriedLabelValue, +) int { + for i, metric := range metrics { + if matchLabels(desc, metric.values, labels, curry) { + return i + } + } + return len(metrics) +} + +func matchLabelValues(values, lvs []string, curry []curriedLabelValue) bool { + if len(values) != len(lvs)+len(curry) { + return false + } + var iLVs, iCurry int + for i, v := range values { + if iCurry < len(curry) && curry[iCurry].index == i { + if v != curry[iCurry].value { + return false + } + iCurry++ + continue + } + if v != lvs[iLVs] { + return false + } + iLVs++ + } + return true +} + +func matchLabels(desc *Desc, values []string, labels Labels, curry []curriedLabelValue) bool { + if len(values) != len(labels)+len(curry) { + return false + } + iCurry := 0 + for i, k := range desc.variableLabels { + if iCurry < len(curry) && curry[iCurry].index == i { + if values[i] != curry[iCurry].value { + return false + } + iCurry++ + continue + } + if values[i] != labels[k.Name] { + return false + } + } + return true +} + +func extractLabelValues(desc *Desc, labels Labels, curry []curriedLabelValue) []string { + labelValues := make([]string, len(labels)+len(curry)) + iCurry := 0 + for i, k := range desc.variableLabels { + if iCurry < len(curry) && curry[iCurry].index == i { + labelValues[i] = curry[iCurry].value + iCurry++ + continue + } + labelValues[i] = labels[k.Name] + } + return labelValues +} + +func inlineLabelValues(lvs []string, curry []curriedLabelValue) []string { + labelValues := make([]string, len(lvs)+len(curry)) + var iCurry, iLVs int + for i := range labelValues { + if iCurry < len(curry) && curry[iCurry].index == i { + labelValues[i] = curry[iCurry].value + iCurry++ + continue + } + labelValues[i] = lvs[iLVs] + iLVs++ + } + return labelValues +} + +func constrainLabels(desc *Desc, labels Labels) Labels { + constrainedLabels := getLabelsFromPool() + for l, v := range labels { + if i, ok := indexOf(l, desc.variableLabels.labelNames()); ok { + v = desc.variableLabels[i].Constrain(v) + } + + constrainedLabels[l] = v + } + + return constrainedLabels +} + +func constrainLabelValues(desc *Desc, lvs []string, curry []curriedLabelValue) []string { + constrainedValues := make([]string, len(lvs)) + var iCurry, iLVs int + for i := 0; i < len(lvs)+len(curry); i++ { + if iCurry < len(curry) && curry[iCurry].index == i { + iCurry++ + continue + } + + if i < len(desc.variableLabels) { + constrainedValues[iLVs] = desc.variableLabels[i].Constrain(lvs[iLVs]) + } else { + constrainedValues[iLVs] = lvs[iLVs] + } + iLVs++ + } + return constrainedValues +} diff --git a/vendor/github.com/prometheus/client_model/go/metrics.pb.go b/vendor/github.com/prometheus/client_model/go/metrics.pb.go new file mode 100644 index 0000000000..2b5bca4b99 --- /dev/null +++ b/vendor/github.com/prometheus/client_model/go/metrics.pb.go @@ -0,0 +1,1332 @@ +// Copyright 2013 Prometheus Team +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.30.0 +// protoc v3.20.3 +// source: io/prometheus/client/metrics.proto + +package io_prometheus_client + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + timestamppb "google.golang.org/protobuf/types/known/timestamppb" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type MetricType int32 + +const ( + // COUNTER must use the Metric field "counter". + MetricType_COUNTER MetricType = 0 + // GAUGE must use the Metric field "gauge". + MetricType_GAUGE MetricType = 1 + // SUMMARY must use the Metric field "summary". + MetricType_SUMMARY MetricType = 2 + // UNTYPED must use the Metric field "untyped". + MetricType_UNTYPED MetricType = 3 + // HISTOGRAM must use the Metric field "histogram". + MetricType_HISTOGRAM MetricType = 4 + // GAUGE_HISTOGRAM must use the Metric field "histogram". + MetricType_GAUGE_HISTOGRAM MetricType = 5 +) + +// Enum value maps for MetricType. +var ( + MetricType_name = map[int32]string{ + 0: "COUNTER", + 1: "GAUGE", + 2: "SUMMARY", + 3: "UNTYPED", + 4: "HISTOGRAM", + 5: "GAUGE_HISTOGRAM", + } + MetricType_value = map[string]int32{ + "COUNTER": 0, + "GAUGE": 1, + "SUMMARY": 2, + "UNTYPED": 3, + "HISTOGRAM": 4, + "GAUGE_HISTOGRAM": 5, + } +) + +func (x MetricType) Enum() *MetricType { + p := new(MetricType) + *p = x + return p +} + +func (x MetricType) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (MetricType) Descriptor() protoreflect.EnumDescriptor { + return file_io_prometheus_client_metrics_proto_enumTypes[0].Descriptor() +} + +func (MetricType) Type() protoreflect.EnumType { + return &file_io_prometheus_client_metrics_proto_enumTypes[0] +} + +func (x MetricType) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Do not use. +func (x *MetricType) UnmarshalJSON(b []byte) error { + num, err := protoimpl.X.UnmarshalJSONEnum(x.Descriptor(), b) + if err != nil { + return err + } + *x = MetricType(num) + return nil +} + +// Deprecated: Use MetricType.Descriptor instead. 
+func (MetricType) EnumDescriptor() ([]byte, []int) { + return file_io_prometheus_client_metrics_proto_rawDescGZIP(), []int{0} +} + +type LabelPair struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Name *string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty"` + Value *string `protobuf:"bytes,2,opt,name=value" json:"value,omitempty"` +} + +func (x *LabelPair) Reset() { + *x = LabelPair{} + if protoimpl.UnsafeEnabled { + mi := &file_io_prometheus_client_metrics_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *LabelPair) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*LabelPair) ProtoMessage() {} + +func (x *LabelPair) ProtoReflect() protoreflect.Message { + mi := &file_io_prometheus_client_metrics_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use LabelPair.ProtoReflect.Descriptor instead. +func (*LabelPair) Descriptor() ([]byte, []int) { + return file_io_prometheus_client_metrics_proto_rawDescGZIP(), []int{0} +} + +func (x *LabelPair) GetName() string { + if x != nil && x.Name != nil { + return *x.Name + } + return "" +} + +func (x *LabelPair) GetValue() string { + if x != nil && x.Value != nil { + return *x.Value + } + return "" +} + +type Gauge struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Value *float64 `protobuf:"fixed64,1,opt,name=value" json:"value,omitempty"` +} + +func (x *Gauge) Reset() { + *x = Gauge{} + if protoimpl.UnsafeEnabled { + mi := &file_io_prometheus_client_metrics_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Gauge) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Gauge) ProtoMessage() {} + +func (x *Gauge) ProtoReflect() protoreflect.Message { + mi := &file_io_prometheus_client_metrics_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Gauge.ProtoReflect.Descriptor instead. 
+func (*Gauge) Descriptor() ([]byte, []int) { + return file_io_prometheus_client_metrics_proto_rawDescGZIP(), []int{1} +} + +func (x *Gauge) GetValue() float64 { + if x != nil && x.Value != nil { + return *x.Value + } + return 0 +} + +type Counter struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Value *float64 `protobuf:"fixed64,1,opt,name=value" json:"value,omitempty"` + Exemplar *Exemplar `protobuf:"bytes,2,opt,name=exemplar" json:"exemplar,omitempty"` +} + +func (x *Counter) Reset() { + *x = Counter{} + if protoimpl.UnsafeEnabled { + mi := &file_io_prometheus_client_metrics_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Counter) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Counter) ProtoMessage() {} + +func (x *Counter) ProtoReflect() protoreflect.Message { + mi := &file_io_prometheus_client_metrics_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Counter.ProtoReflect.Descriptor instead. +func (*Counter) Descriptor() ([]byte, []int) { + return file_io_prometheus_client_metrics_proto_rawDescGZIP(), []int{2} +} + +func (x *Counter) GetValue() float64 { + if x != nil && x.Value != nil { + return *x.Value + } + return 0 +} + +func (x *Counter) GetExemplar() *Exemplar { + if x != nil { + return x.Exemplar + } + return nil +} + +type Quantile struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Quantile *float64 `protobuf:"fixed64,1,opt,name=quantile" json:"quantile,omitempty"` + Value *float64 `protobuf:"fixed64,2,opt,name=value" json:"value,omitempty"` +} + +func (x *Quantile) Reset() { + *x = Quantile{} + if protoimpl.UnsafeEnabled { + mi := &file_io_prometheus_client_metrics_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Quantile) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Quantile) ProtoMessage() {} + +func (x *Quantile) ProtoReflect() protoreflect.Message { + mi := &file_io_prometheus_client_metrics_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Quantile.ProtoReflect.Descriptor instead. 
+func (*Quantile) Descriptor() ([]byte, []int) { + return file_io_prometheus_client_metrics_proto_rawDescGZIP(), []int{3} +} + +func (x *Quantile) GetQuantile() float64 { + if x != nil && x.Quantile != nil { + return *x.Quantile + } + return 0 +} + +func (x *Quantile) GetValue() float64 { + if x != nil && x.Value != nil { + return *x.Value + } + return 0 +} + +type Summary struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + SampleCount *uint64 `protobuf:"varint,1,opt,name=sample_count,json=sampleCount" json:"sample_count,omitempty"` + SampleSum *float64 `protobuf:"fixed64,2,opt,name=sample_sum,json=sampleSum" json:"sample_sum,omitempty"` + Quantile []*Quantile `protobuf:"bytes,3,rep,name=quantile" json:"quantile,omitempty"` +} + +func (x *Summary) Reset() { + *x = Summary{} + if protoimpl.UnsafeEnabled { + mi := &file_io_prometheus_client_metrics_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Summary) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Summary) ProtoMessage() {} + +func (x *Summary) ProtoReflect() protoreflect.Message { + mi := &file_io_prometheus_client_metrics_proto_msgTypes[4] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Summary.ProtoReflect.Descriptor instead. +func (*Summary) Descriptor() ([]byte, []int) { + return file_io_prometheus_client_metrics_proto_rawDescGZIP(), []int{4} +} + +func (x *Summary) GetSampleCount() uint64 { + if x != nil && x.SampleCount != nil { + return *x.SampleCount + } + return 0 +} + +func (x *Summary) GetSampleSum() float64 { + if x != nil && x.SampleSum != nil { + return *x.SampleSum + } + return 0 +} + +func (x *Summary) GetQuantile() []*Quantile { + if x != nil { + return x.Quantile + } + return nil +} + +type Untyped struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Value *float64 `protobuf:"fixed64,1,opt,name=value" json:"value,omitempty"` +} + +func (x *Untyped) Reset() { + *x = Untyped{} + if protoimpl.UnsafeEnabled { + mi := &file_io_prometheus_client_metrics_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Untyped) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Untyped) ProtoMessage() {} + +func (x *Untyped) ProtoReflect() protoreflect.Message { + mi := &file_io_prometheus_client_metrics_proto_msgTypes[5] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Untyped.ProtoReflect.Descriptor instead. 
+func (*Untyped) Descriptor() ([]byte, []int) { + return file_io_prometheus_client_metrics_proto_rawDescGZIP(), []int{5} +} + +func (x *Untyped) GetValue() float64 { + if x != nil && x.Value != nil { + return *x.Value + } + return 0 +} + +type Histogram struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + SampleCount *uint64 `protobuf:"varint,1,opt,name=sample_count,json=sampleCount" json:"sample_count,omitempty"` + SampleCountFloat *float64 `protobuf:"fixed64,4,opt,name=sample_count_float,json=sampleCountFloat" json:"sample_count_float,omitempty"` // Overrides sample_count if > 0. + SampleSum *float64 `protobuf:"fixed64,2,opt,name=sample_sum,json=sampleSum" json:"sample_sum,omitempty"` + // Buckets for the conventional histogram. + Bucket []*Bucket `protobuf:"bytes,3,rep,name=bucket" json:"bucket,omitempty"` // Ordered in increasing order of upper_bound, +Inf bucket is optional. + // schema defines the bucket schema. Currently, valid numbers are -4 <= n <= 8. + // They are all for base-2 bucket schemas, where 1 is a bucket boundary in each case, and + // then each power of two is divided into 2^n logarithmic buckets. + // Or in other words, each bucket boundary is the previous boundary times 2^(2^-n). + // In the future, more bucket schemas may be added using numbers < -4 or > 8. + Schema *int32 `protobuf:"zigzag32,5,opt,name=schema" json:"schema,omitempty"` + ZeroThreshold *float64 `protobuf:"fixed64,6,opt,name=zero_threshold,json=zeroThreshold" json:"zero_threshold,omitempty"` // Breadth of the zero bucket. + ZeroCount *uint64 `protobuf:"varint,7,opt,name=zero_count,json=zeroCount" json:"zero_count,omitempty"` // Count in zero bucket. + ZeroCountFloat *float64 `protobuf:"fixed64,8,opt,name=zero_count_float,json=zeroCountFloat" json:"zero_count_float,omitempty"` // Overrides sb_zero_count if > 0. + // Negative buckets for the native histogram. + NegativeSpan []*BucketSpan `protobuf:"bytes,9,rep,name=negative_span,json=negativeSpan" json:"negative_span,omitempty"` + // Use either "negative_delta" or "negative_count", the former for + // regular histograms with integer counts, the latter for float + // histograms. + NegativeDelta []int64 `protobuf:"zigzag64,10,rep,name=negative_delta,json=negativeDelta" json:"negative_delta,omitempty"` // Count delta of each bucket compared to previous one (or to zero for 1st bucket). + NegativeCount []float64 `protobuf:"fixed64,11,rep,name=negative_count,json=negativeCount" json:"negative_count,omitempty"` // Absolute count of each bucket. + // Positive buckets for the native histogram. + PositiveSpan []*BucketSpan `protobuf:"bytes,12,rep,name=positive_span,json=positiveSpan" json:"positive_span,omitempty"` + // Use either "positive_delta" or "positive_count", the former for + // regular histograms with integer counts, the latter for float + // histograms. + PositiveDelta []int64 `protobuf:"zigzag64,13,rep,name=positive_delta,json=positiveDelta" json:"positive_delta,omitempty"` // Count delta of each bucket compared to previous one (or to zero for 1st bucket). + PositiveCount []float64 `protobuf:"fixed64,14,rep,name=positive_count,json=positiveCount" json:"positive_count,omitempty"` // Absolute count of each bucket. 
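+	// A worked example for the schema comment above: schema=3 makes each
+	// bucket boundary the previous one times 2^(2^-3) ≈ 1.0905, i.e. every
+	// power of two is split into 2^3 = 8 logarithmic buckets.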
+} + +func (x *Histogram) Reset() { + *x = Histogram{} + if protoimpl.UnsafeEnabled { + mi := &file_io_prometheus_client_metrics_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Histogram) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Histogram) ProtoMessage() {} + +func (x *Histogram) ProtoReflect() protoreflect.Message { + mi := &file_io_prometheus_client_metrics_proto_msgTypes[6] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Histogram.ProtoReflect.Descriptor instead. +func (*Histogram) Descriptor() ([]byte, []int) { + return file_io_prometheus_client_metrics_proto_rawDescGZIP(), []int{6} +} + +func (x *Histogram) GetSampleCount() uint64 { + if x != nil && x.SampleCount != nil { + return *x.SampleCount + } + return 0 +} + +func (x *Histogram) GetSampleCountFloat() float64 { + if x != nil && x.SampleCountFloat != nil { + return *x.SampleCountFloat + } + return 0 +} + +func (x *Histogram) GetSampleSum() float64 { + if x != nil && x.SampleSum != nil { + return *x.SampleSum + } + return 0 +} + +func (x *Histogram) GetBucket() []*Bucket { + if x != nil { + return x.Bucket + } + return nil +} + +func (x *Histogram) GetSchema() int32 { + if x != nil && x.Schema != nil { + return *x.Schema + } + return 0 +} + +func (x *Histogram) GetZeroThreshold() float64 { + if x != nil && x.ZeroThreshold != nil { + return *x.ZeroThreshold + } + return 0 +} + +func (x *Histogram) GetZeroCount() uint64 { + if x != nil && x.ZeroCount != nil { + return *x.ZeroCount + } + return 0 +} + +func (x *Histogram) GetZeroCountFloat() float64 { + if x != nil && x.ZeroCountFloat != nil { + return *x.ZeroCountFloat + } + return 0 +} + +func (x *Histogram) GetNegativeSpan() []*BucketSpan { + if x != nil { + return x.NegativeSpan + } + return nil +} + +func (x *Histogram) GetNegativeDelta() []int64 { + if x != nil { + return x.NegativeDelta + } + return nil +} + +func (x *Histogram) GetNegativeCount() []float64 { + if x != nil { + return x.NegativeCount + } + return nil +} + +func (x *Histogram) GetPositiveSpan() []*BucketSpan { + if x != nil { + return x.PositiveSpan + } + return nil +} + +func (x *Histogram) GetPositiveDelta() []int64 { + if x != nil { + return x.PositiveDelta + } + return nil +} + +func (x *Histogram) GetPositiveCount() []float64 { + if x != nil { + return x.PositiveCount + } + return nil +} + +// A Bucket of a conventional histogram, each of which is treated as +// an individual counter-like time series by Prometheus. +type Bucket struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + CumulativeCount *uint64 `protobuf:"varint,1,opt,name=cumulative_count,json=cumulativeCount" json:"cumulative_count,omitempty"` // Cumulative in increasing order. + CumulativeCountFloat *float64 `protobuf:"fixed64,4,opt,name=cumulative_count_float,json=cumulativeCountFloat" json:"cumulative_count_float,omitempty"` // Overrides cumulative_count if > 0. + UpperBound *float64 `protobuf:"fixed64,2,opt,name=upper_bound,json=upperBound" json:"upper_bound,omitempty"` // Inclusive. 
+ Exemplar *Exemplar `protobuf:"bytes,3,opt,name=exemplar" json:"exemplar,omitempty"` +} + +func (x *Bucket) Reset() { + *x = Bucket{} + if protoimpl.UnsafeEnabled { + mi := &file_io_prometheus_client_metrics_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Bucket) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Bucket) ProtoMessage() {} + +func (x *Bucket) ProtoReflect() protoreflect.Message { + mi := &file_io_prometheus_client_metrics_proto_msgTypes[7] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Bucket.ProtoReflect.Descriptor instead. +func (*Bucket) Descriptor() ([]byte, []int) { + return file_io_prometheus_client_metrics_proto_rawDescGZIP(), []int{7} +} + +func (x *Bucket) GetCumulativeCount() uint64 { + if x != nil && x.CumulativeCount != nil { + return *x.CumulativeCount + } + return 0 +} + +func (x *Bucket) GetCumulativeCountFloat() float64 { + if x != nil && x.CumulativeCountFloat != nil { + return *x.CumulativeCountFloat + } + return 0 +} + +func (x *Bucket) GetUpperBound() float64 { + if x != nil && x.UpperBound != nil { + return *x.UpperBound + } + return 0 +} + +func (x *Bucket) GetExemplar() *Exemplar { + if x != nil { + return x.Exemplar + } + return nil +} + +// A BucketSpan defines a number of consecutive buckets in a native +// histogram with their offset. Logically, it would be more +// straightforward to include the bucket counts in the Span. However, +// the protobuf representation is more compact in the way the data is +// structured here (with all the buckets in a single array separate +// from the Spans). +type BucketSpan struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Offset *int32 `protobuf:"zigzag32,1,opt,name=offset" json:"offset,omitempty"` // Gap to previous span, or starting point for 1st span (which can be negative). + Length *uint32 `protobuf:"varint,2,opt,name=length" json:"length,omitempty"` // Length of consecutive buckets. +} + +func (x *BucketSpan) Reset() { + *x = BucketSpan{} + if protoimpl.UnsafeEnabled { + mi := &file_io_prometheus_client_metrics_proto_msgTypes[8] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *BucketSpan) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*BucketSpan) ProtoMessage() {} + +func (x *BucketSpan) ProtoReflect() protoreflect.Message { + mi := &file_io_prometheus_client_metrics_proto_msgTypes[8] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use BucketSpan.ProtoReflect.Descriptor instead. 
+func (*BucketSpan) Descriptor() ([]byte, []int) { + return file_io_prometheus_client_metrics_proto_rawDescGZIP(), []int{8} +} + +func (x *BucketSpan) GetOffset() int32 { + if x != nil && x.Offset != nil { + return *x.Offset + } + return 0 +} + +func (x *BucketSpan) GetLength() uint32 { + if x != nil && x.Length != nil { + return *x.Length + } + return 0 +} + +type Exemplar struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Label []*LabelPair `protobuf:"bytes,1,rep,name=label" json:"label,omitempty"` + Value *float64 `protobuf:"fixed64,2,opt,name=value" json:"value,omitempty"` + Timestamp *timestamppb.Timestamp `protobuf:"bytes,3,opt,name=timestamp" json:"timestamp,omitempty"` // OpenMetrics-style. +} + +func (x *Exemplar) Reset() { + *x = Exemplar{} + if protoimpl.UnsafeEnabled { + mi := &file_io_prometheus_client_metrics_proto_msgTypes[9] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Exemplar) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Exemplar) ProtoMessage() {} + +func (x *Exemplar) ProtoReflect() protoreflect.Message { + mi := &file_io_prometheus_client_metrics_proto_msgTypes[9] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Exemplar.ProtoReflect.Descriptor instead. +func (*Exemplar) Descriptor() ([]byte, []int) { + return file_io_prometheus_client_metrics_proto_rawDescGZIP(), []int{9} +} + +func (x *Exemplar) GetLabel() []*LabelPair { + if x != nil { + return x.Label + } + return nil +} + +func (x *Exemplar) GetValue() float64 { + if x != nil && x.Value != nil { + return *x.Value + } + return 0 +} + +func (x *Exemplar) GetTimestamp() *timestamppb.Timestamp { + if x != nil { + return x.Timestamp + } + return nil +} + +type Metric struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Label []*LabelPair `protobuf:"bytes,1,rep,name=label" json:"label,omitempty"` + Gauge *Gauge `protobuf:"bytes,2,opt,name=gauge" json:"gauge,omitempty"` + Counter *Counter `protobuf:"bytes,3,opt,name=counter" json:"counter,omitempty"` + Summary *Summary `protobuf:"bytes,4,opt,name=summary" json:"summary,omitempty"` + Untyped *Untyped `protobuf:"bytes,5,opt,name=untyped" json:"untyped,omitempty"` + Histogram *Histogram `protobuf:"bytes,7,opt,name=histogram" json:"histogram,omitempty"` + TimestampMs *int64 `protobuf:"varint,6,opt,name=timestamp_ms,json=timestampMs" json:"timestamp_ms,omitempty"` +} + +func (x *Metric) Reset() { + *x = Metric{} + if protoimpl.UnsafeEnabled { + mi := &file_io_prometheus_client_metrics_proto_msgTypes[10] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Metric) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Metric) ProtoMessage() {} + +func (x *Metric) ProtoReflect() protoreflect.Message { + mi := &file_io_prometheus_client_metrics_proto_msgTypes[10] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Metric.ProtoReflect.Descriptor instead. 
+func (*Metric) Descriptor() ([]byte, []int) { + return file_io_prometheus_client_metrics_proto_rawDescGZIP(), []int{10} +} + +func (x *Metric) GetLabel() []*LabelPair { + if x != nil { + return x.Label + } + return nil +} + +func (x *Metric) GetGauge() *Gauge { + if x != nil { + return x.Gauge + } + return nil +} + +func (x *Metric) GetCounter() *Counter { + if x != nil { + return x.Counter + } + return nil +} + +func (x *Metric) GetSummary() *Summary { + if x != nil { + return x.Summary + } + return nil +} + +func (x *Metric) GetUntyped() *Untyped { + if x != nil { + return x.Untyped + } + return nil +} + +func (x *Metric) GetHistogram() *Histogram { + if x != nil { + return x.Histogram + } + return nil +} + +func (x *Metric) GetTimestampMs() int64 { + if x != nil && x.TimestampMs != nil { + return *x.TimestampMs + } + return 0 +} + +type MetricFamily struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Name *string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty"` + Help *string `protobuf:"bytes,2,opt,name=help" json:"help,omitempty"` + Type *MetricType `protobuf:"varint,3,opt,name=type,enum=io.prometheus.client.MetricType" json:"type,omitempty"` + Metric []*Metric `protobuf:"bytes,4,rep,name=metric" json:"metric,omitempty"` +} + +func (x *MetricFamily) Reset() { + *x = MetricFamily{} + if protoimpl.UnsafeEnabled { + mi := &file_io_prometheus_client_metrics_proto_msgTypes[11] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *MetricFamily) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*MetricFamily) ProtoMessage() {} + +func (x *MetricFamily) ProtoReflect() protoreflect.Message { + mi := &file_io_prometheus_client_metrics_proto_msgTypes[11] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use MetricFamily.ProtoReflect.Descriptor instead. 
+func (*MetricFamily) Descriptor() ([]byte, []int) { + return file_io_prometheus_client_metrics_proto_rawDescGZIP(), []int{11} +} + +func (x *MetricFamily) GetName() string { + if x != nil && x.Name != nil { + return *x.Name + } + return "" +} + +func (x *MetricFamily) GetHelp() string { + if x != nil && x.Help != nil { + return *x.Help + } + return "" +} + +func (x *MetricFamily) GetType() MetricType { + if x != nil && x.Type != nil { + return *x.Type + } + return MetricType_COUNTER +} + +func (x *MetricFamily) GetMetric() []*Metric { + if x != nil { + return x.Metric + } + return nil +} + +var File_io_prometheus_client_metrics_proto protoreflect.FileDescriptor + +var file_io_prometheus_client_metrics_proto_rawDesc = []byte{ + 0x0a, 0x22, 0x69, 0x6f, 0x2f, 0x70, 0x72, 0x6f, 0x6d, 0x65, 0x74, 0x68, 0x65, 0x75, 0x73, 0x2f, + 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x2f, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x73, 0x2e, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x14, 0x69, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x6d, 0x65, 0x74, 0x68, + 0x65, 0x75, 0x73, 0x2e, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x1a, 0x1f, 0x67, 0x6f, 0x6f, 0x67, + 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, + 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x35, 0x0a, 0x09, 0x4c, + 0x61, 0x62, 0x65, 0x6c, 0x50, 0x61, 0x69, 0x72, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x14, 0x0a, 0x05, + 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, + 0x75, 0x65, 0x22, 0x1d, 0x0a, 0x05, 0x47, 0x61, 0x75, 0x67, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x76, + 0x61, 0x6c, 0x75, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x01, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, + 0x65, 0x22, 0x5b, 0x0a, 0x07, 0x43, 0x6f, 0x75, 0x6e, 0x74, 0x65, 0x72, 0x12, 0x14, 0x0a, 0x05, + 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x01, 0x52, 0x05, 0x76, 0x61, 0x6c, + 0x75, 0x65, 0x12, 0x3a, 0x0a, 0x08, 0x65, 0x78, 0x65, 0x6d, 0x70, 0x6c, 0x61, 0x72, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x69, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x6d, 0x65, 0x74, + 0x68, 0x65, 0x75, 0x73, 0x2e, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x78, 0x65, 0x6d, + 0x70, 0x6c, 0x61, 0x72, 0x52, 0x08, 0x65, 0x78, 0x65, 0x6d, 0x70, 0x6c, 0x61, 0x72, 0x22, 0x3c, + 0x0a, 0x08, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x6c, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x71, 0x75, + 0x61, 0x6e, 0x74, 0x69, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x01, 0x52, 0x08, 0x71, 0x75, + 0x61, 0x6e, 0x74, 0x69, 0x6c, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x01, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x22, 0x87, 0x01, 0x0a, + 0x07, 0x53, 0x75, 0x6d, 0x6d, 0x61, 0x72, 0x79, 0x12, 0x21, 0x0a, 0x0c, 0x73, 0x61, 0x6d, 0x70, + 0x6c, 0x65, 0x5f, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x0b, + 0x73, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x43, 0x6f, 0x75, 0x6e, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x73, + 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x5f, 0x73, 0x75, 0x6d, 0x18, 0x02, 0x20, 0x01, 0x28, 0x01, 0x52, + 0x09, 0x73, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x53, 0x75, 0x6d, 0x12, 0x3a, 0x0a, 0x08, 0x71, 0x75, + 0x61, 0x6e, 0x74, 0x69, 0x6c, 0x65, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x69, + 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x6d, 0x65, 0x74, 0x68, 0x65, 0x75, 0x73, 0x2e, 0x63, 0x6c, 0x69, + 0x65, 0x6e, 0x74, 0x2e, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x69, 0x6c, 0x65, 
0x52, 0x08, 0x71, 0x75, + 0x61, 0x6e, 0x74, 0x69, 0x6c, 0x65, 0x22, 0x1f, 0x0a, 0x07, 0x55, 0x6e, 0x74, 0x79, 0x70, 0x65, + 0x64, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x01, + 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x22, 0xe3, 0x04, 0x0a, 0x09, 0x48, 0x69, 0x73, 0x74, + 0x6f, 0x67, 0x72, 0x61, 0x6d, 0x12, 0x21, 0x0a, 0x0c, 0x73, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x5f, + 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x0b, 0x73, 0x61, 0x6d, + 0x70, 0x6c, 0x65, 0x43, 0x6f, 0x75, 0x6e, 0x74, 0x12, 0x2c, 0x0a, 0x12, 0x73, 0x61, 0x6d, 0x70, + 0x6c, 0x65, 0x5f, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x66, 0x6c, 0x6f, 0x61, 0x74, 0x18, 0x04, + 0x20, 0x01, 0x28, 0x01, 0x52, 0x10, 0x73, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x43, 0x6f, 0x75, 0x6e, + 0x74, 0x46, 0x6c, 0x6f, 0x61, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x61, 0x6d, 0x70, 0x6c, 0x65, + 0x5f, 0x73, 0x75, 0x6d, 0x18, 0x02, 0x20, 0x01, 0x28, 0x01, 0x52, 0x09, 0x73, 0x61, 0x6d, 0x70, + 0x6c, 0x65, 0x53, 0x75, 0x6d, 0x12, 0x34, 0x0a, 0x06, 0x62, 0x75, 0x63, 0x6b, 0x65, 0x74, 0x18, + 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x69, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x6d, 0x65, + 0x74, 0x68, 0x65, 0x75, 0x73, 0x2e, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x2e, 0x42, 0x75, 0x63, + 0x6b, 0x65, 0x74, 0x52, 0x06, 0x62, 0x75, 0x63, 0x6b, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x73, + 0x63, 0x68, 0x65, 0x6d, 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, 0x11, 0x52, 0x06, 0x73, 0x63, 0x68, + 0x65, 0x6d, 0x61, 0x12, 0x25, 0x0a, 0x0e, 0x7a, 0x65, 0x72, 0x6f, 0x5f, 0x74, 0x68, 0x72, 0x65, + 0x73, 0x68, 0x6f, 0x6c, 0x64, 0x18, 0x06, 0x20, 0x01, 0x28, 0x01, 0x52, 0x0d, 0x7a, 0x65, 0x72, + 0x6f, 0x54, 0x68, 0x72, 0x65, 0x73, 0x68, 0x6f, 0x6c, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x7a, 0x65, + 0x72, 0x6f, 0x5f, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x18, 0x07, 0x20, 0x01, 0x28, 0x04, 0x52, 0x09, + 0x7a, 0x65, 0x72, 0x6f, 0x43, 0x6f, 0x75, 0x6e, 0x74, 0x12, 0x28, 0x0a, 0x10, 0x7a, 0x65, 0x72, + 0x6f, 0x5f, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x66, 0x6c, 0x6f, 0x61, 0x74, 0x18, 0x08, 0x20, + 0x01, 0x28, 0x01, 0x52, 0x0e, 0x7a, 0x65, 0x72, 0x6f, 0x43, 0x6f, 0x75, 0x6e, 0x74, 0x46, 0x6c, + 0x6f, 0x61, 0x74, 0x12, 0x45, 0x0a, 0x0d, 0x6e, 0x65, 0x67, 0x61, 0x74, 0x69, 0x76, 0x65, 0x5f, + 0x73, 0x70, 0x61, 0x6e, 0x18, 0x09, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x20, 0x2e, 0x69, 0x6f, 0x2e, + 0x70, 0x72, 0x6f, 0x6d, 0x65, 0x74, 0x68, 0x65, 0x75, 0x73, 0x2e, 0x63, 0x6c, 0x69, 0x65, 0x6e, + 0x74, 0x2e, 0x42, 0x75, 0x63, 0x6b, 0x65, 0x74, 0x53, 0x70, 0x61, 0x6e, 0x52, 0x0c, 0x6e, 0x65, + 0x67, 0x61, 0x74, 0x69, 0x76, 0x65, 0x53, 0x70, 0x61, 0x6e, 0x12, 0x25, 0x0a, 0x0e, 0x6e, 0x65, + 0x67, 0x61, 0x74, 0x69, 0x76, 0x65, 0x5f, 0x64, 0x65, 0x6c, 0x74, 0x61, 0x18, 0x0a, 0x20, 0x03, + 0x28, 0x12, 0x52, 0x0d, 0x6e, 0x65, 0x67, 0x61, 0x74, 0x69, 0x76, 0x65, 0x44, 0x65, 0x6c, 0x74, + 0x61, 0x12, 0x25, 0x0a, 0x0e, 0x6e, 0x65, 0x67, 0x61, 0x74, 0x69, 0x76, 0x65, 0x5f, 0x63, 0x6f, + 0x75, 0x6e, 0x74, 0x18, 0x0b, 0x20, 0x03, 0x28, 0x01, 0x52, 0x0d, 0x6e, 0x65, 0x67, 0x61, 0x74, + 0x69, 0x76, 0x65, 0x43, 0x6f, 0x75, 0x6e, 0x74, 0x12, 0x45, 0x0a, 0x0d, 0x70, 0x6f, 0x73, 0x69, + 0x74, 0x69, 0x76, 0x65, 0x5f, 0x73, 0x70, 0x61, 0x6e, 0x18, 0x0c, 0x20, 0x03, 0x28, 0x0b, 0x32, + 0x20, 0x2e, 0x69, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x6d, 0x65, 0x74, 0x68, 0x65, 0x75, 0x73, 0x2e, + 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x2e, 0x42, 0x75, 0x63, 0x6b, 0x65, 0x74, 0x53, 0x70, 0x61, + 0x6e, 0x52, 0x0c, 0x70, 0x6f, 0x73, 0x69, 0x74, 0x69, 0x76, 0x65, 0x53, 0x70, 0x61, 0x6e, 0x12, + 
0x25, 0x0a, 0x0e, 0x70, 0x6f, 0x73, 0x69, 0x74, 0x69, 0x76, 0x65, 0x5f, 0x64, 0x65, 0x6c, 0x74, + 0x61, 0x18, 0x0d, 0x20, 0x03, 0x28, 0x12, 0x52, 0x0d, 0x70, 0x6f, 0x73, 0x69, 0x74, 0x69, 0x76, + 0x65, 0x44, 0x65, 0x6c, 0x74, 0x61, 0x12, 0x25, 0x0a, 0x0e, 0x70, 0x6f, 0x73, 0x69, 0x74, 0x69, + 0x76, 0x65, 0x5f, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x18, 0x0e, 0x20, 0x03, 0x28, 0x01, 0x52, 0x0d, + 0x70, 0x6f, 0x73, 0x69, 0x74, 0x69, 0x76, 0x65, 0x43, 0x6f, 0x75, 0x6e, 0x74, 0x22, 0xc6, 0x01, + 0x0a, 0x06, 0x42, 0x75, 0x63, 0x6b, 0x65, 0x74, 0x12, 0x29, 0x0a, 0x10, 0x63, 0x75, 0x6d, 0x75, + 0x6c, 0x61, 0x74, 0x69, 0x76, 0x65, 0x5f, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x04, 0x52, 0x0f, 0x63, 0x75, 0x6d, 0x75, 0x6c, 0x61, 0x74, 0x69, 0x76, 0x65, 0x43, 0x6f, + 0x75, 0x6e, 0x74, 0x12, 0x34, 0x0a, 0x16, 0x63, 0x75, 0x6d, 0x75, 0x6c, 0x61, 0x74, 0x69, 0x76, + 0x65, 0x5f, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x5f, 0x66, 0x6c, 0x6f, 0x61, 0x74, 0x18, 0x04, 0x20, + 0x01, 0x28, 0x01, 0x52, 0x14, 0x63, 0x75, 0x6d, 0x75, 0x6c, 0x61, 0x74, 0x69, 0x76, 0x65, 0x43, + 0x6f, 0x75, 0x6e, 0x74, 0x46, 0x6c, 0x6f, 0x61, 0x74, 0x12, 0x1f, 0x0a, 0x0b, 0x75, 0x70, 0x70, + 0x65, 0x72, 0x5f, 0x62, 0x6f, 0x75, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x01, 0x52, 0x0a, + 0x75, 0x70, 0x70, 0x65, 0x72, 0x42, 0x6f, 0x75, 0x6e, 0x64, 0x12, 0x3a, 0x0a, 0x08, 0x65, 0x78, + 0x65, 0x6d, 0x70, 0x6c, 0x61, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x69, + 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x6d, 0x65, 0x74, 0x68, 0x65, 0x75, 0x73, 0x2e, 0x63, 0x6c, 0x69, + 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x78, 0x65, 0x6d, 0x70, 0x6c, 0x61, 0x72, 0x52, 0x08, 0x65, 0x78, + 0x65, 0x6d, 0x70, 0x6c, 0x61, 0x72, 0x22, 0x3c, 0x0a, 0x0a, 0x42, 0x75, 0x63, 0x6b, 0x65, 0x74, + 0x53, 0x70, 0x61, 0x6e, 0x12, 0x16, 0x0a, 0x06, 0x6f, 0x66, 0x66, 0x73, 0x65, 0x74, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x11, 0x52, 0x06, 0x6f, 0x66, 0x66, 0x73, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, + 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x06, 0x6c, 0x65, + 0x6e, 0x67, 0x74, 0x68, 0x22, 0x91, 0x01, 0x0a, 0x08, 0x45, 0x78, 0x65, 0x6d, 0x70, 0x6c, 0x61, + 0x72, 0x12, 0x35, 0x0a, 0x05, 0x6c, 0x61, 0x62, 0x65, 0x6c, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, + 0x32, 0x1f, 0x2e, 0x69, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x6d, 0x65, 0x74, 0x68, 0x65, 0x75, 0x73, + 0x2e, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x2e, 0x4c, 0x61, 0x62, 0x65, 0x6c, 0x50, 0x61, 0x69, + 0x72, 0x52, 0x05, 0x6c, 0x61, 0x62, 0x65, 0x6c, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, + 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x01, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x12, 0x38, + 0x0a, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, 0x03, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x74, + 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x22, 0xff, 0x02, 0x0a, 0x06, 0x4d, 0x65, 0x74, + 0x72, 0x69, 0x63, 0x12, 0x35, 0x0a, 0x05, 0x6c, 0x61, 0x62, 0x65, 0x6c, 0x18, 0x01, 0x20, 0x03, + 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x69, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x6d, 0x65, 0x74, 0x68, 0x65, + 0x75, 0x73, 0x2e, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x2e, 0x4c, 0x61, 0x62, 0x65, 0x6c, 0x50, + 0x61, 0x69, 0x72, 0x52, 0x05, 0x6c, 0x61, 0x62, 0x65, 0x6c, 0x12, 0x31, 0x0a, 0x05, 0x67, 0x61, + 0x75, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x69, 0x6f, 0x2e, 0x70, + 0x72, 0x6f, 0x6d, 0x65, 
0x74, 0x68, 0x65, 0x75, 0x73, 0x2e, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, + 0x2e, 0x47, 0x61, 0x75, 0x67, 0x65, 0x52, 0x05, 0x67, 0x61, 0x75, 0x67, 0x65, 0x12, 0x37, 0x0a, + 0x07, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x65, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, + 0x2e, 0x69, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x6d, 0x65, 0x74, 0x68, 0x65, 0x75, 0x73, 0x2e, 0x63, + 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x6f, 0x75, 0x6e, 0x74, 0x65, 0x72, 0x52, 0x07, 0x63, + 0x6f, 0x75, 0x6e, 0x74, 0x65, 0x72, 0x12, 0x37, 0x0a, 0x07, 0x73, 0x75, 0x6d, 0x6d, 0x61, 0x72, + 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x69, 0x6f, 0x2e, 0x70, 0x72, 0x6f, + 0x6d, 0x65, 0x74, 0x68, 0x65, 0x75, 0x73, 0x2e, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x2e, 0x53, + 0x75, 0x6d, 0x6d, 0x61, 0x72, 0x79, 0x52, 0x07, 0x73, 0x75, 0x6d, 0x6d, 0x61, 0x72, 0x79, 0x12, + 0x37, 0x0a, 0x07, 0x75, 0x6e, 0x74, 0x79, 0x70, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, + 0x32, 0x1d, 0x2e, 0x69, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x6d, 0x65, 0x74, 0x68, 0x65, 0x75, 0x73, + 0x2e, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x2e, 0x55, 0x6e, 0x74, 0x79, 0x70, 0x65, 0x64, 0x52, + 0x07, 0x75, 0x6e, 0x74, 0x79, 0x70, 0x65, 0x64, 0x12, 0x3d, 0x0a, 0x09, 0x68, 0x69, 0x73, 0x74, + 0x6f, 0x67, 0x72, 0x61, 0x6d, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x69, 0x6f, + 0x2e, 0x70, 0x72, 0x6f, 0x6d, 0x65, 0x74, 0x68, 0x65, 0x75, 0x73, 0x2e, 0x63, 0x6c, 0x69, 0x65, + 0x6e, 0x74, 0x2e, 0x48, 0x69, 0x73, 0x74, 0x6f, 0x67, 0x72, 0x61, 0x6d, 0x52, 0x09, 0x68, 0x69, + 0x73, 0x74, 0x6f, 0x67, 0x72, 0x61, 0x6d, 0x12, 0x21, 0x0a, 0x0c, 0x74, 0x69, 0x6d, 0x65, 0x73, + 0x74, 0x61, 0x6d, 0x70, 0x5f, 0x6d, 0x73, 0x18, 0x06, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x74, + 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x4d, 0x73, 0x22, 0xa2, 0x01, 0x0a, 0x0c, 0x4d, + 0x65, 0x74, 0x72, 0x69, 0x63, 0x46, 0x61, 0x6d, 0x69, 0x6c, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x6e, + 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, + 0x12, 0x0a, 0x04, 0x68, 0x65, 0x6c, 0x70, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x68, + 0x65, 0x6c, 0x70, 0x12, 0x34, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, + 0x0e, 0x32, 0x20, 0x2e, 0x69, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x6d, 0x65, 0x74, 0x68, 0x65, 0x75, + 0x73, 0x2e, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x2e, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x54, + 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x34, 0x0a, 0x06, 0x6d, 0x65, 0x74, + 0x72, 0x69, 0x63, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x69, 0x6f, 0x2e, 0x70, + 0x72, 0x6f, 0x6d, 0x65, 0x74, 0x68, 0x65, 0x75, 0x73, 0x2e, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, + 0x2e, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x52, 0x06, 0x6d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x2a, + 0x62, 0x0a, 0x0a, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x54, 0x79, 0x70, 0x65, 0x12, 0x0b, 0x0a, + 0x07, 0x43, 0x4f, 0x55, 0x4e, 0x54, 0x45, 0x52, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x47, 0x41, + 0x55, 0x47, 0x45, 0x10, 0x01, 0x12, 0x0b, 0x0a, 0x07, 0x53, 0x55, 0x4d, 0x4d, 0x41, 0x52, 0x59, + 0x10, 0x02, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x54, 0x59, 0x50, 0x45, 0x44, 0x10, 0x03, 0x12, + 0x0d, 0x0a, 0x09, 0x48, 0x49, 0x53, 0x54, 0x4f, 0x47, 0x52, 0x41, 0x4d, 0x10, 0x04, 0x12, 0x13, + 0x0a, 0x0f, 0x47, 0x41, 0x55, 0x47, 0x45, 0x5f, 0x48, 0x49, 0x53, 0x54, 0x4f, 0x47, 0x52, 0x41, + 0x4d, 0x10, 0x05, 0x42, 0x52, 0x0a, 0x14, 0x69, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x6d, 0x65, 0x74, + 0x68, 0x65, 0x75, 0x73, 0x2e, 0x63, 0x6c, 0x69, 
0x65, 0x6e, 0x74, 0x5a, 0x3a, 0x67, 0x69, 0x74, + 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x70, 0x72, 0x6f, 0x6d, 0x65, 0x74, 0x68, 0x65, + 0x75, 0x73, 0x2f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x2f, + 0x67, 0x6f, 0x3b, 0x69, 0x6f, 0x5f, 0x70, 0x72, 0x6f, 0x6d, 0x65, 0x74, 0x68, 0x65, 0x75, 0x73, + 0x5f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, +} + +var ( + file_io_prometheus_client_metrics_proto_rawDescOnce sync.Once + file_io_prometheus_client_metrics_proto_rawDescData = file_io_prometheus_client_metrics_proto_rawDesc +) + +func file_io_prometheus_client_metrics_proto_rawDescGZIP() []byte { + file_io_prometheus_client_metrics_proto_rawDescOnce.Do(func() { + file_io_prometheus_client_metrics_proto_rawDescData = protoimpl.X.CompressGZIP(file_io_prometheus_client_metrics_proto_rawDescData) + }) + return file_io_prometheus_client_metrics_proto_rawDescData +} + +var file_io_prometheus_client_metrics_proto_enumTypes = make([]protoimpl.EnumInfo, 1) +var file_io_prometheus_client_metrics_proto_msgTypes = make([]protoimpl.MessageInfo, 12) +var file_io_prometheus_client_metrics_proto_goTypes = []interface{}{ + (MetricType)(0), // 0: io.prometheus.client.MetricType + (*LabelPair)(nil), // 1: io.prometheus.client.LabelPair + (*Gauge)(nil), // 2: io.prometheus.client.Gauge + (*Counter)(nil), // 3: io.prometheus.client.Counter + (*Quantile)(nil), // 4: io.prometheus.client.Quantile + (*Summary)(nil), // 5: io.prometheus.client.Summary + (*Untyped)(nil), // 6: io.prometheus.client.Untyped + (*Histogram)(nil), // 7: io.prometheus.client.Histogram + (*Bucket)(nil), // 8: io.prometheus.client.Bucket + (*BucketSpan)(nil), // 9: io.prometheus.client.BucketSpan + (*Exemplar)(nil), // 10: io.prometheus.client.Exemplar + (*Metric)(nil), // 11: io.prometheus.client.Metric + (*MetricFamily)(nil), // 12: io.prometheus.client.MetricFamily + (*timestamppb.Timestamp)(nil), // 13: google.protobuf.Timestamp +} +var file_io_prometheus_client_metrics_proto_depIdxs = []int32{ + 10, // 0: io.prometheus.client.Counter.exemplar:type_name -> io.prometheus.client.Exemplar + 4, // 1: io.prometheus.client.Summary.quantile:type_name -> io.prometheus.client.Quantile + 8, // 2: io.prometheus.client.Histogram.bucket:type_name -> io.prometheus.client.Bucket + 9, // 3: io.prometheus.client.Histogram.negative_span:type_name -> io.prometheus.client.BucketSpan + 9, // 4: io.prometheus.client.Histogram.positive_span:type_name -> io.prometheus.client.BucketSpan + 10, // 5: io.prometheus.client.Bucket.exemplar:type_name -> io.prometheus.client.Exemplar + 1, // 6: io.prometheus.client.Exemplar.label:type_name -> io.prometheus.client.LabelPair + 13, // 7: io.prometheus.client.Exemplar.timestamp:type_name -> google.protobuf.Timestamp + 1, // 8: io.prometheus.client.Metric.label:type_name -> io.prometheus.client.LabelPair + 2, // 9: io.prometheus.client.Metric.gauge:type_name -> io.prometheus.client.Gauge + 3, // 10: io.prometheus.client.Metric.counter:type_name -> io.prometheus.client.Counter + 5, // 11: io.prometheus.client.Metric.summary:type_name -> io.prometheus.client.Summary + 6, // 12: io.prometheus.client.Metric.untyped:type_name -> io.prometheus.client.Untyped + 7, // 13: io.prometheus.client.Metric.histogram:type_name -> io.prometheus.client.Histogram + 0, // 14: io.prometheus.client.MetricFamily.type:type_name -> io.prometheus.client.MetricType + 11, // 15: io.prometheus.client.MetricFamily.metric:type_name -> io.prometheus.client.Metric + 16, // [16:16] is the sub-list for method output_type 
+ 16, // [16:16] is the sub-list for method input_type + 16, // [16:16] is the sub-list for extension type_name + 16, // [16:16] is the sub-list for extension extendee + 0, // [0:16] is the sub-list for field type_name +} + +func init() { file_io_prometheus_client_metrics_proto_init() } +func file_io_prometheus_client_metrics_proto_init() { + if File_io_prometheus_client_metrics_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_io_prometheus_client_metrics_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*LabelPair); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_io_prometheus_client_metrics_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Gauge); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_io_prometheus_client_metrics_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Counter); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_io_prometheus_client_metrics_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Quantile); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_io_prometheus_client_metrics_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Summary); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_io_prometheus_client_metrics_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Untyped); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_io_prometheus_client_metrics_proto_msgTypes[6].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Histogram); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_io_prometheus_client_metrics_proto_msgTypes[7].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Bucket); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_io_prometheus_client_metrics_proto_msgTypes[8].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*BucketSpan); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_io_prometheus_client_metrics_proto_msgTypes[9].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Exemplar); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_io_prometheus_client_metrics_proto_msgTypes[10].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Metric); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_io_prometheus_client_metrics_proto_msgTypes[11].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*MetricFamily); i { + case 0: + return &v.state + case 
1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_io_prometheus_client_metrics_proto_rawDesc, + NumEnums: 1, + NumMessages: 12, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_io_prometheus_client_metrics_proto_goTypes, + DependencyIndexes: file_io_prometheus_client_metrics_proto_depIdxs, + EnumInfos: file_io_prometheus_client_metrics_proto_enumTypes, + MessageInfos: file_io_prometheus_client_metrics_proto_msgTypes, + }.Build() + File_io_prometheus_client_metrics_proto = out.File + file_io_prometheus_client_metrics_proto_rawDesc = nil + file_io_prometheus_client_metrics_proto_goTypes = nil + file_io_prometheus_client_metrics_proto_depIdxs = nil +} diff --git a/vendor/github.com/rabbitmq/amqp091-go/CHANGELOG.md b/vendor/github.com/rabbitmq/amqp091-go/CHANGELOG.md new file mode 100644 index 0000000000..02523c2522 --- /dev/null +++ b/vendor/github.com/rabbitmq/amqp091-go/CHANGELOG.md @@ -0,0 +1,283 @@ +# Changelog + +## [v1.8.0](https://github.com/rabbitmq/amqp091-go/tree/v1.8.0) (2023-03-21) + +[Full Changelog](https://github.com/rabbitmq/amqp091-go/compare/v1.7.0...v1.8.0) + +**Closed issues:** + +- memory leak [\#179](https://github.com/rabbitmq/amqp091-go/issues/179) +- the publishWithContext interface will not return when it times out [\#178](https://github.com/rabbitmq/amqp091-go/issues/178) + +**Merged pull requests:** + +- Fix race condition on confirms [\#183](https://github.com/rabbitmq/amqp091-go/pull/183) ([calloway-jacob](https://github.com/calloway-jacob)) +- Add a CloseDeadline function to Connection [\#181](https://github.com/rabbitmq/amqp091-go/pull/181) ([Zerpet](https://github.com/Zerpet)) +- Fix memory leaks [\#180](https://github.com/rabbitmq/amqp091-go/pull/180) ([GXKe](https://github.com/GXKe)) +- Bump go.uber.org/goleak from 1.2.0 to 1.2.1 [\#177](https://github.com/rabbitmq/amqp091-go/pull/177) ([dependabot[bot]](https://github.com/apps/dependabot)) + +## [v1.7.0](https://github.com/rabbitmq/amqp091-go/tree/v1.7.0) (2023-02-09) + +[Full Changelog](https://github.com/rabbitmq/amqp091-go/compare/v1.6.1...v1.7.0) + +**Closed issues:** + +- \#31 resurfacing \(?\) [\#170](https://github.com/rabbitmq/amqp091-go/issues/170) +- Deprecate QueueInspect [\#167](https://github.com/rabbitmq/amqp091-go/issues/167) +- v1.6.0 causing rabbit connection errors [\#160](https://github.com/rabbitmq/amqp091-go/issues/160) + +**Merged pull requests:** + +- Set channels and allocator to nil in shutdown [\#172](https://github.com/rabbitmq/amqp091-go/pull/172) ([lukebakken](https://github.com/lukebakken)) +- Fix racing in Open [\#171](https://github.com/rabbitmq/amqp091-go/pull/171) ([Zerpet](https://github.com/Zerpet)) +- adding go 1.20 to tests [\#169](https://github.com/rabbitmq/amqp091-go/pull/169) ([halilylm](https://github.com/halilylm)) +- Deprecate the QueueInspect function [\#168](https://github.com/rabbitmq/amqp091-go/pull/168) ([lukebakken](https://github.com/lukebakken)) +- Check if channel is nil before updating it [\#150](https://github.com/rabbitmq/amqp091-go/pull/150) ([julienschmidt](https://github.com/julienschmidt)) + +## [v1.6.1](https://github.com/rabbitmq/amqp091-go/tree/v1.6.1) (2023-02-01) + +[Full Changelog](https://github.com/rabbitmq/amqp091-go/compare/v1.6.1-rc.2...v1.6.1) + +**Merged pull requests:** + +- Update Makefile targets related to RabbitMQ 
[\#163](https://github.com/rabbitmq/amqp091-go/pull/163) ([Zerpet](https://github.com/Zerpet)) + +## [v1.6.1-rc.2](https://github.com/rabbitmq/amqp091-go/tree/v1.6.1-rc.2) (2023-01-31) + +[Full Changelog](https://github.com/rabbitmq/amqp091-go/compare/v1.6.1-rc.1...v1.6.1-rc.2) + +**Merged pull requests:** + +- Do not overly protect writes [\#162](https://github.com/rabbitmq/amqp091-go/pull/162) ([lukebakken](https://github.com/lukebakken)) + +## [v1.6.1-rc.1](https://github.com/rabbitmq/amqp091-go/tree/v1.6.1-rc.1) (2023-01-31) + +[Full Changelog](https://github.com/rabbitmq/amqp091-go/compare/v1.6.0...v1.6.1-rc.1) + +**Closed issues:** + +- Calling Channel\(\) on an empty connection panics [\#148](https://github.com/rabbitmq/amqp091-go/issues/148) + +**Merged pull requests:** + +- Ensure flush happens and correctly lock connection for a series of unflushed writes [\#161](https://github.com/rabbitmq/amqp091-go/pull/161) ([lukebakken](https://github.com/lukebakken)) + +## [v1.6.0](https://github.com/rabbitmq/amqp091-go/tree/v1.6.0) (2023-01-20) + +[Full Changelog](https://github.com/rabbitmq/amqp091-go/compare/v1.5.0...v1.6.0) + +**Implemented enhancements:** + +- Add constants for Queue arguments [\#145](https://github.com/rabbitmq/amqp091-go/pull/145) ([Zerpet](https://github.com/Zerpet)) + +**Closed issues:** + +- README not up to date [\#154](https://github.com/rabbitmq/amqp091-go/issues/154) +- Allow re-using default connection config \(custom properties\) [\#152](https://github.com/rabbitmq/amqp091-go/issues/152) +- Rename package name to amqp in V2 [\#151](https://github.com/rabbitmq/amqp091-go/issues/151) +- Helper types to declare quorum queues [\#144](https://github.com/rabbitmq/amqp091-go/issues/144) +- Inefficient use of buffers reduces potential throughput for basicPublish with small messages. 
[\#141](https://github.com/rabbitmq/amqp091-go/issues/141) +- bug, close cause panic [\#130](https://github.com/rabbitmq/amqp091-go/issues/130) +- Publishing Headers are unable to store Table with slice values [\#125](https://github.com/rabbitmq/amqp091-go/issues/125) +- Example client can deadlock in Close due to unconsumed confirmations [\#122](https://github.com/rabbitmq/amqp091-go/issues/122) +- SAC not working properly [\#106](https://github.com/rabbitmq/amqp091-go/issues/106) + +**Merged pull requests:** + +- Add automatic CHANGELOG.md generation [\#158](https://github.com/rabbitmq/amqp091-go/pull/158) ([lukebakken](https://github.com/lukebakken)) +- Supply library-defined props with NewConnectionProperties [\#157](https://github.com/rabbitmq/amqp091-go/pull/157) ([slagiewka](https://github.com/slagiewka)) +- Fix linter warnings [\#156](https://github.com/rabbitmq/amqp091-go/pull/156) ([Zerpet](https://github.com/Zerpet)) +- Remove outdated information from README [\#155](https://github.com/rabbitmq/amqp091-go/pull/155) ([scriptcoded](https://github.com/scriptcoded)) +- Add example producer using DeferredConfirm [\#149](https://github.com/rabbitmq/amqp091-go/pull/149) ([Zerpet](https://github.com/Zerpet)) +- Ensure code is formatted [\#147](https://github.com/rabbitmq/amqp091-go/pull/147) ([lukebakken](https://github.com/lukebakken)) +- Fix inefficient use of buffers that reduces the potential throughput of basicPublish [\#142](https://github.com/rabbitmq/amqp091-go/pull/142) ([fadams](https://github.com/fadams)) +- Do not embed context in DeferredConfirmation [\#140](https://github.com/rabbitmq/amqp091-go/pull/140) ([tie](https://github.com/tie)) +- Add constant for default exchange [\#139](https://github.com/rabbitmq/amqp091-go/pull/139) ([marlongerson](https://github.com/marlongerson)) +- Fix indentation and remove unnecessary instructions [\#138](https://github.com/rabbitmq/amqp091-go/pull/138) ([alraujo](https://github.com/alraujo)) +- Remove unnecessary instruction [\#135](https://github.com/rabbitmq/amqp091-go/pull/135) ([alraujo](https://github.com/alraujo)) +- Fix example client to avoid deadlock in Close [\#123](https://github.com/rabbitmq/amqp091-go/pull/123) ([Zerpet](https://github.com/Zerpet)) +- Bump go.uber.org/goleak from 1.1.12 to 1.2.0 [\#116](https://github.com/rabbitmq/amqp091-go/pull/116) ([dependabot[bot]](https://github.com/apps/dependabot)) + +## [v1.5.0](https://github.com/rabbitmq/amqp091-go/tree/v1.5.0) (2022-09-07) + +[Full Changelog](https://github.com/rabbitmq/amqp091-go/compare/v1.4.0...v1.5.0) + +**Implemented enhancements:** + +- Provide a friendly way to set connection name [\#105](https://github.com/rabbitmq/amqp091-go/issues/105) + +**Closed issues:** + +- Support connection.update-secret [\#107](https://github.com/rabbitmq/amqp091-go/issues/107) +- Example Client: Implementation of a Consumer with reconnection support [\#40](https://github.com/rabbitmq/amqp091-go/issues/40) + +**Merged pull requests:** + +- use PublishWithContext instead of Publish [\#115](https://github.com/rabbitmq/amqp091-go/pull/115) ([Gsantomaggio](https://github.com/Gsantomaggio)) +- Add support for connection.update-secret [\#114](https://github.com/rabbitmq/amqp091-go/pull/114) ([Zerpet](https://github.com/Zerpet)) +- Remove warning on RabbitMQ tutorials in go [\#113](https://github.com/rabbitmq/amqp091-go/pull/113) ([ChunyiLyu](https://github.com/ChunyiLyu)) +- Update AMQP Spec [\#110](https://github.com/rabbitmq/amqp091-go/pull/110) ([Zerpet](https://github.com/Zerpet)) 
+- Add an example of reliable consumer [\#109](https://github.com/rabbitmq/amqp091-go/pull/109) ([Zerpet](https://github.com/Zerpet)) +- Add convenience function to set connection name [\#108](https://github.com/rabbitmq/amqp091-go/pull/108) ([Zerpet](https://github.com/Zerpet)) + +## [v1.4.0](https://github.com/rabbitmq/amqp091-go/tree/v1.4.0) (2022-07-19) + +[Full Changelog](https://github.com/rabbitmq/amqp091-go/compare/v1.3.4...v1.4.0) + +**Closed issues:** + +- target machine actively refused connection [\#99](https://github.com/rabbitmq/amqp091-go/issues/99) +- 504 channel/connection is not open error occurred in multiple connection with same rabbitmq service [\#97](https://github.com/rabbitmq/amqp091-go/issues/97) +- Add possible cancel of DeferredConfirmation [\#92](https://github.com/rabbitmq/amqp091-go/issues/92) +- Documentation [\#89](https://github.com/rabbitmq/amqp091-go/issues/89) +- Channel Close gets stuck after closing a connection \(via management UI\) [\#88](https://github.com/rabbitmq/amqp091-go/issues/88) +- this library has same issue [\#83](https://github.com/rabbitmq/amqp091-go/issues/83) +- Provide a logging interface [\#81](https://github.com/rabbitmq/amqp091-go/issues/81) +- 1.4.0 release checklist [\#77](https://github.com/rabbitmq/amqp091-go/issues/77) +- Data race in the client example [\#72](https://github.com/rabbitmq/amqp091-go/issues/72) +- reader go routine hangs and leaks when Connection.Close\(\) is called multiple times [\#69](https://github.com/rabbitmq/amqp091-go/issues/69) +- Support auto-reconnect and cluster [\#65](https://github.com/rabbitmq/amqp091-go/issues/65) +- Connection/Channel Deadlock [\#32](https://github.com/rabbitmq/amqp091-go/issues/32) +- Closing connection and/or channel hangs NotifyPublish is used [\#21](https://github.com/rabbitmq/amqp091-go/issues/21) +- Consumer channel isn't closed in the event of unexpected disconnection [\#18](https://github.com/rabbitmq/amqp091-go/issues/18) + +**Merged pull requests:** + +- fix race condition with context close and confirm at the same time on DeferredConfirmation. 
[\#101](https://github.com/rabbitmq/amqp091-go/pull/101) ([sapk](https://github.com/sapk)) +- Add build TLS config from URI [\#98](https://github.com/rabbitmq/amqp091-go/pull/98) ([reddec](https://github.com/reddec)) +- Use context for Publish methods [\#96](https://github.com/rabbitmq/amqp091-go/pull/96) ([sapk](https://github.com/sapk)) +- Added function to get the remote peer's IP address \(conn.RemoteAddr\(\)\) [\#95](https://github.com/rabbitmq/amqp091-go/pull/95) ([rabb1t](https://github.com/rabb1t)) +- Update connection documentation [\#90](https://github.com/rabbitmq/amqp091-go/pull/90) ([Zerpet](https://github.com/Zerpet)) +- Revert test to demonstrate actual bug [\#87](https://github.com/rabbitmq/amqp091-go/pull/87) ([lukebakken](https://github.com/lukebakken)) +- Minor improvements to examples [\#86](https://github.com/rabbitmq/amqp091-go/pull/86) ([lukebakken](https://github.com/lukebakken)) +- Do not skip flaky test in CI [\#85](https://github.com/rabbitmq/amqp091-go/pull/85) ([lukebakken](https://github.com/lukebakken)) +- Add logging [\#84](https://github.com/rabbitmq/amqp091-go/pull/84) ([lukebakken](https://github.com/lukebakken)) +- Add a win32 build [\#82](https://github.com/rabbitmq/amqp091-go/pull/82) ([lukebakken](https://github.com/lukebakken)) +- channel: return nothing instead of always a nil-error in receive methods [\#80](https://github.com/rabbitmq/amqp091-go/pull/80) ([fho](https://github.com/fho)) +- update the contributing & readme files, improve makefile [\#79](https://github.com/rabbitmq/amqp091-go/pull/79) ([fho](https://github.com/fho)) +- Fix lint errors [\#78](https://github.com/rabbitmq/amqp091-go/pull/78) ([lukebakken](https://github.com/lukebakken)) +- ci: run golangci-lint [\#76](https://github.com/rabbitmq/amqp091-go/pull/76) ([fho](https://github.com/fho)) +- ci: run test via make & remove travis CI config [\#75](https://github.com/rabbitmq/amqp091-go/pull/75) ([fho](https://github.com/fho)) +- ci: run tests with race detector [\#74](https://github.com/rabbitmq/amqp091-go/pull/74) ([fho](https://github.com/fho)) +- Detect go routine leaks in integration testcases [\#73](https://github.com/rabbitmq/amqp091-go/pull/73) ([fho](https://github.com/fho)) +- connection: fix: reader go-routine is leaked on connection close [\#70](https://github.com/rabbitmq/amqp091-go/pull/70) ([fho](https://github.com/fho)) +- adding best practises for NotifyPublish for issue\_21 scenario [\#68](https://github.com/rabbitmq/amqp091-go/pull/68) ([DanielePalaia](https://github.com/DanielePalaia)) +- Update Go version [\#67](https://github.com/rabbitmq/amqp091-go/pull/67) ([Zerpet](https://github.com/Zerpet)) +- Regenerate certs with SHA256 to fix test with Go 1.18+ [\#66](https://github.com/rabbitmq/amqp091-go/pull/66) ([anthonyfok](https://github.com/anthonyfok)) + +## [v1.3.4](https://github.com/rabbitmq/amqp091-go/tree/v1.3.4) (2022-04-01) + +[Full Changelog](https://github.com/rabbitmq/amqp091-go/compare/v1.3.3...v1.3.4) + +**Merged pull requests:** + +- bump version to 1.3.4 [\#63](https://github.com/rabbitmq/amqp091-go/pull/63) ([DanielePalaia](https://github.com/DanielePalaia)) +- updating doc [\#62](https://github.com/rabbitmq/amqp091-go/pull/62) ([DanielePalaia](https://github.com/DanielePalaia)) + +## [v1.3.3](https://github.com/rabbitmq/amqp091-go/tree/v1.3.3) (2022-04-01) + +[Full Changelog](https://github.com/rabbitmq/amqp091-go/compare/v1.3.2...v1.3.3) + +**Closed issues:** + +- Add Client Version [\#49](https://github.com/rabbitmq/amqp091-go/issues/49) +- 
OpenTelemetry Propagation [\#22](https://github.com/rabbitmq/amqp091-go/issues/22) + +**Merged pull requests:** + +- bump buildVersion for release [\#61](https://github.com/rabbitmq/amqp091-go/pull/61) ([DanielePalaia](https://github.com/DanielePalaia)) +- adding documentation for notifyClose best pratices [\#60](https://github.com/rabbitmq/amqp091-go/pull/60) ([DanielePalaia](https://github.com/DanielePalaia)) +- adding documentation on NotifyClose of connection and channel to enfo… [\#59](https://github.com/rabbitmq/amqp091-go/pull/59) ([DanielePalaia](https://github.com/DanielePalaia)) + +## [v1.3.2](https://github.com/rabbitmq/amqp091-go/tree/v1.3.2) (2022-03-28) + +[Full Changelog](https://github.com/rabbitmq/amqp091-go/compare/v1.3.1...v1.3.2) + +**Closed issues:** + +- Potential race condition in Connection module [\#31](https://github.com/rabbitmq/amqp091-go/issues/31) + +**Merged pull requests:** + +- bump versioning to 1.3.2 [\#58](https://github.com/rabbitmq/amqp091-go/pull/58) ([DanielePalaia](https://github.com/DanielePalaia)) + +## [v1.3.1](https://github.com/rabbitmq/amqp091-go/tree/v1.3.1) (2022-03-25) + +[Full Changelog](https://github.com/rabbitmq/amqp091-go/compare/v1.3.0...v1.3.1) + +**Closed issues:** + +- Possible deadlock on DeferredConfirmation.Wait\(\) [\#46](https://github.com/rabbitmq/amqp091-go/issues/46) +- Call to Delivery.Ack blocks indefinitely in case of disconnection [\#19](https://github.com/rabbitmq/amqp091-go/issues/19) +- Unexpacted behavor of channel.IsClosed\(\) [\#14](https://github.com/rabbitmq/amqp091-go/issues/14) +- A possible dead lock in connection close notification Go channel [\#11](https://github.com/rabbitmq/amqp091-go/issues/11) + +**Merged pull requests:** + +- These ones were the ones testing Open scenarios. 
The issue is that Op… [\#57](https://github.com/rabbitmq/amqp091-go/pull/57) ([DanielePalaia](https://github.com/DanielePalaia)) +- changing defaultVersion to buildVersion and create a simple change\_ve… [\#54](https://github.com/rabbitmq/amqp091-go/pull/54) ([DanielePalaia](https://github.com/DanielePalaia)) +- adding integration test for issue 11 [\#50](https://github.com/rabbitmq/amqp091-go/pull/50) ([DanielePalaia](https://github.com/DanielePalaia)) +- Remove the old link product [\#48](https://github.com/rabbitmq/amqp091-go/pull/48) ([Gsantomaggio](https://github.com/Gsantomaggio)) +- Fix deadlock on DeferredConfirmations [\#47](https://github.com/rabbitmq/amqp091-go/pull/47) ([SpencerTorres](https://github.com/SpencerTorres)) +- Example client: Rename Stream\(\) to Consume\(\) to avoid confusion with RabbitMQ streams [\#39](https://github.com/rabbitmq/amqp091-go/pull/39) ([andygrunwald](https://github.com/andygrunwald)) +- Example client: Rename `name` to `queueName` to make the usage clear and explicit [\#38](https://github.com/rabbitmq/amqp091-go/pull/38) ([andygrunwald](https://github.com/andygrunwald)) +- Client example: Renamed concept "Session" to "Client" [\#37](https://github.com/rabbitmq/amqp091-go/pull/37) ([andygrunwald](https://github.com/andygrunwald)) +- delete unuseful code [\#36](https://github.com/rabbitmq/amqp091-go/pull/36) ([liutaot](https://github.com/liutaot)) +- Client Example: Fix closing order [\#35](https://github.com/rabbitmq/amqp091-go/pull/35) ([andygrunwald](https://github.com/andygrunwald)) +- Client example: Use instance logger instead of global logger [\#34](https://github.com/rabbitmq/amqp091-go/pull/34) ([andygrunwald](https://github.com/andygrunwald)) + +## [v1.3.0](https://github.com/rabbitmq/amqp091-go/tree/v1.3.0) (2022-01-13) + +[Full Changelog](https://github.com/rabbitmq/amqp091-go/compare/v1.2.0...v1.3.0) + +**Closed issues:** + +- documentation of changes triggering version updates [\#29](https://github.com/rabbitmq/amqp091-go/issues/29) +- Persistent messages folder [\#27](https://github.com/rabbitmq/amqp091-go/issues/27) + +**Merged pull requests:** + +- Expose a method to enable out-of-order Publisher Confirms [\#33](https://github.com/rabbitmq/amqp091-go/pull/33) ([benmoss](https://github.com/benmoss)) +- Fix Signed 8-bit headers being treated as unsigned [\#26](https://github.com/rabbitmq/amqp091-go/pull/26) ([alex-goodisman](https://github.com/alex-goodisman)) + +## [v1.2.0](https://github.com/rabbitmq/amqp091-go/tree/v1.2.0) (2021-11-17) + +[Full Changelog](https://github.com/rabbitmq/amqp091-go/compare/v1.1.0...v1.2.0) + +**Closed issues:** + +- No access to this vhost [\#24](https://github.com/rabbitmq/amqp091-go/issues/24) +- copyright issue? 
[\#12](https://github.com/rabbitmq/amqp091-go/issues/12) +- A possible dead lock when publishing message with confirmation [\#10](https://github.com/rabbitmq/amqp091-go/issues/10) +- Semver release [\#7](https://github.com/rabbitmq/amqp091-go/issues/7) + +**Merged pull requests:** + +- Fix deadlock between publishing and receiving confirms [\#25](https://github.com/rabbitmq/amqp091-go/pull/25) ([benmoss](https://github.com/benmoss)) +- Add GetNextPublishSeqNo for channel in confirm mode [\#23](https://github.com/rabbitmq/amqp091-go/pull/23) ([kamal-github](https://github.com/kamal-github)) +- Added support for cert-only login without user and password [\#20](https://github.com/rabbitmq/amqp091-go/pull/20) ([mihaitodor](https://github.com/mihaitodor)) + +## [v1.1.0](https://github.com/rabbitmq/amqp091-go/tree/v1.1.0) (2021-09-21) + +[Full Changelog](https://github.com/rabbitmq/amqp091-go/compare/ebd83429aa8cb06fa569473f623e87675f96d3a9...v1.1.0) + +**Closed issues:** + +- AMQPLAIN authentication does not work [\#15](https://github.com/rabbitmq/amqp091-go/issues/15) + +**Merged pull requests:** + +- Fix AMQPLAIN authentication mechanism [\#16](https://github.com/rabbitmq/amqp091-go/pull/16) ([hodbn](https://github.com/hodbn)) +- connection: clarify documented behavior of NotifyClose [\#13](https://github.com/rabbitmq/amqp091-go/pull/13) ([pabigot](https://github.com/pabigot)) +- Add a link to pkg.go.dev API docs [\#9](https://github.com/rabbitmq/amqp091-go/pull/9) ([benmoss](https://github.com/benmoss)) +- add test go version 1.16.x and 1.17.x [\#8](https://github.com/rabbitmq/amqp091-go/pull/8) ([k4n4ry](https://github.com/k4n4ry)) +- fix typos [\#6](https://github.com/rabbitmq/amqp091-go/pull/6) ([h44z](https://github.com/h44z)) +- Heartbeat interval should be timeout/2 [\#5](https://github.com/rabbitmq/amqp091-go/pull/5) ([ifo20](https://github.com/ifo20)) +- Exporting Channel State [\#4](https://github.com/rabbitmq/amqp091-go/pull/4) ([eibrunorodrigues](https://github.com/eibrunorodrigues)) +- Add codeql analysis [\#3](https://github.com/rabbitmq/amqp091-go/pull/3) ([MirahImage](https://github.com/MirahImage)) +- Add PR github action. [\#2](https://github.com/rabbitmq/amqp091-go/pull/2) ([MirahImage](https://github.com/MirahImage)) +- Update Copyright Statement [\#1](https://github.com/rabbitmq/amqp091-go/pull/1) ([rlewis24](https://github.com/rlewis24)) + + + +\* *This Changelog was automatically generated by [github_changelog_generator](https://github.com/github-changelog-generator/github-changelog-generator)* diff --git a/vendor/github.com/rabbitmq/amqp091-go/Makefile b/vendor/github.com/rabbitmq/amqp091-go/Makefile new file mode 100644 index 0000000000..69e9e2be12 --- /dev/null +++ b/vendor/github.com/rabbitmq/amqp091-go/Makefile @@ -0,0 +1,41 @@ +.DEFAULT_GOAL := list + +# Insert a comment starting with '##' after a target, and it will be printed by 'make' and 'make list' +.PHONY: list +list: ## list Makefile targets + @echo "The most used targets: \n" + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-20s\033[0m %s\n", $$1, $$2}' + +.PHONY: check-fmt +check-fmt: ## Ensure code is formatted + gofmt -l -d . # For the sake of debugging + test -z "$$(gofmt -l .)" + +.PHONY: fmt +fmt: ## Run go fmt against code + go fmt ./... + +.PHONY: tests +tests: ## Run all tests and requires a running rabbitmq-server. 
Use GO_TEST_FLAGS to add extra flags to go test + go test -race -v -tags integration $(GO_TEST_FLAGS) + +.PHONY: tests-docker +tests-docker: rabbitmq-server + RABBITMQ_RABBITMQCTL_PATH="DOCKER:$(CONTAINER_NAME)" go test -race -v -tags integration $(GO_TEST_FLAGS) + $(MAKE) stop-rabbitmq-server + +.PHONY: check +check: + golangci-lint run ./... + +CONTAINER_NAME ?= amqp091-go-rabbitmq + +.PHONY: rabbitmq-server +rabbitmq-server: ## Start a RabbitMQ server using Docker. Container name can be customised with CONTAINER_NAME=some-rabbit + docker run --detach --rm --name $(CONTAINER_NAME) \ + --publish 5672:5672 --publish 15672:15672 \ + --pull always rabbitmq:3-management + +.PHONY: stop-rabbitmq-server +stop-rabbitmq-server: ## Stop a RabbitMQ server using Docker. Container name can be customised with CONTAINER_NAME=some-rabbit + docker stop $(CONTAINER_NAME) diff --git a/vendor/github.com/rabbitmq/amqp091-go/RELEASE.md b/vendor/github.com/rabbitmq/amqp091-go/RELEASE.md new file mode 100644 index 0000000000..a1b1ae0c3c --- /dev/null +++ b/vendor/github.com/rabbitmq/amqp091-go/RELEASE.md @@ -0,0 +1,5 @@ +## Changelog Generation + +``` +github_changelog_generator --token GITHUB-TOKEN -u rabbitmq -p amqp091-go --no-unreleased --release-branch main +``` diff --git a/vendor/github.com/rabbitmq/amqp091-go/allocator.go b/vendor/github.com/rabbitmq/amqp091-go/allocator.go new file mode 100644 index 0000000000..0688e4b643 --- /dev/null +++ b/vendor/github.com/rabbitmq/amqp091-go/allocator.go @@ -0,0 +1,111 @@ +// Copyright (c) 2021 VMware, Inc. or its affiliates. All Rights Reserved. +// Copyright (c) 2012-2021, Sean Treadway, SoundCloud Ltd. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package amqp091 + +import ( + "bytes" + "fmt" + "math/big" +) + +const ( + free = 0 + allocated = 1 +) + +// allocator maintains a bitset of allocated numbers. +type allocator struct { + pool *big.Int + last int + low int + high int +} + +// NewAllocator reserves and frees integers out of a range between low and +// high. +// +// O(N) worst case space used, where N is maximum allocated, divided by +// sizeof(big.Word) +func newAllocator(low, high int) *allocator { + return &allocator{ + pool: big.NewInt(0), + last: low, + low: low, + high: high, + } +} + +// String returns a string describing the contents of the allocator like +// "allocator[low..high] reserved..until" +// +// O(N) where N is high-low +func (a allocator) String() string { + b := &bytes.Buffer{} + fmt.Fprintf(b, "allocator[%d..%d]", a.low, a.high) + + for low := a.low; low <= a.high; low++ { + high := low + for a.reserved(high) && high <= a.high { + high++ + } + + if high > low+1 { + fmt.Fprintf(b, " %d..%d", low, high-1) + } else if high > low { + fmt.Fprintf(b, " %d", high-1) + } + + low = high + } + return b.String() +} + +// Next reserves and returns the next available number out of the range between +// low and high. If no number is available, false is returned. +// +// O(N) worst case runtime where N is allocated, but usually O(1) due to a +// rolling index into the oldest allocation. 
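+//
+// A minimal usage sketch (editorial illustration; the 1..65535 range is only
+// an example, matching the AMQP channel-id space):
+//
+//	a := newAllocator(1, 65535)
+//	if id, ok := a.next(); ok {
+//		// ... use id as a channel number ...
+//		a.release(id) // make it available for reuse
+//	}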
+func (a *allocator) next() (int, bool) { + wrapped := a.last + + // Find trailing bit + for ; a.last <= a.high; a.last++ { + if a.reserve(a.last) { + return a.last, true + } + } + + // Find preceding free'd pool + a.last = a.low + + for ; a.last < wrapped; a.last++ { + if a.reserve(a.last) { + return a.last, true + } + } + + return 0, false +} + +// reserve claims the bit if it is not already claimed, returning true if +// successfully claimed. +func (a *allocator) reserve(n int) bool { + if a.reserved(n) { + return false + } + a.pool.SetBit(a.pool, n-a.low, allocated) + return true +} + +// reserved returns true if the integer has been allocated +func (a *allocator) reserved(n int) bool { + return a.pool.Bit(n-a.low) == allocated +} + +// release frees the use of the number for another allocation +func (a *allocator) release(n int) { + a.pool.SetBit(a.pool, n-a.low, free) +} diff --git a/vendor/github.com/rabbitmq/amqp091-go/certs.sh b/vendor/github.com/rabbitmq/amqp091-go/certs.sh new file mode 100644 index 0000000000..403e80c544 --- /dev/null +++ b/vendor/github.com/rabbitmq/amqp091-go/certs.sh @@ -0,0 +1,159 @@ +#!/bin/sh +# +# Creates the CA, server and client certs to be used by tls_test.go +# http://www.rabbitmq.com/ssl.html +# +# Copy stdout into the const section of tls_test.go or use for RabbitMQ +# +root=$PWD/certs + +if [ -f $root/ca/serial ]; then + echo >&2 "Previous installation found" + echo >&2 "Remove $root/ca and rerun to overwrite" + exit 1 +fi + +mkdir -p $root/ca/private +mkdir -p $root/ca/certs +mkdir -p $root/server +mkdir -p $root/client + +cd $root/ca + +chmod 700 private +touch index.txt +echo 'unique_subject = no' > index.txt.attr +echo '01' > serial +echo >openssl.cnf ' +[ ca ] +default_ca = testca + +[ testca ] +dir = . 
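+# (editorial note) The relative paths below resolve against $root/ca, the
+# directory this script cd's into before invoking `openssl ca`.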
+certificate = $dir/cacert.pem +database = $dir/index.txt +new_certs_dir = $dir/certs +private_key = $dir/private/cakey.pem +serial = $dir/serial + +default_crl_days = 7 +default_days = 3650 +default_md = sha256 + +policy = testca_policy +x509_extensions = certificate_extensions + +[ testca_policy ] +commonName = supplied +stateOrProvinceName = optional +countryName = optional +emailAddress = optional +organizationName = optional +organizationalUnitName = optional + +[ certificate_extensions ] +basicConstraints = CA:false + +[ req ] +default_bits = 2048 +default_keyfile = ./private/cakey.pem +default_md = sha256 +prompt = yes +distinguished_name = root_ca_distinguished_name +x509_extensions = root_ca_extensions + +[ root_ca_distinguished_name ] +commonName = hostname + +[ root_ca_extensions ] +basicConstraints = CA:true +keyUsage = keyCertSign, cRLSign + +[ client_ca_extensions ] +basicConstraints = CA:false +keyUsage = digitalSignature +extendedKeyUsage = 1.3.6.1.5.5.7.3.2 + +[ server_ca_extensions ] +basicConstraints = CA:false +keyUsage = keyEncipherment +extendedKeyUsage = 1.3.6.1.5.5.7.3.1 +subjectAltName = @alt_names + +[ alt_names ] +IP.1 = 127.0.0.1 +' + +openssl req \ + -x509 \ + -nodes \ + -config openssl.cnf \ + -newkey rsa:2048 \ + -days 3650 \ + -subj "/CN=MyTestCA/" \ + -out cacert.pem \ + -outform PEM + +openssl x509 \ + -in cacert.pem \ + -out cacert.cer \ + -outform DER + +openssl genrsa -out $root/server/key.pem 2048 +openssl genrsa -out $root/client/key.pem 2048 + +openssl req \ + -new \ + -nodes \ + -config openssl.cnf \ + -subj "/CN=127.0.0.1/O=server/" \ + -key $root/server/key.pem \ + -out $root/server/req.pem \ + -outform PEM + +openssl req \ + -new \ + -nodes \ + -config openssl.cnf \ + -subj "/CN=127.0.0.1/O=client/" \ + -key $root/client/key.pem \ + -out $root/client/req.pem \ + -outform PEM + +openssl ca \ + -config openssl.cnf \ + -in $root/server/req.pem \ + -out $root/server/cert.pem \ + -notext \ + -batch \ + -extensions server_ca_extensions + +openssl ca \ + -config openssl.cnf \ + -in $root/client/req.pem \ + -out $root/client/cert.pem \ + -notext \ + -batch \ + -extensions client_ca_extensions + +cat <<-END +const caCert = \` +`cat $root/ca/cacert.pem` +\` + +const serverCert = \` +`cat $root/server/cert.pem` +\` + +const serverKey = \` +`cat $root/server/key.pem` +\` + +const clientCert = \` +`cat $root/client/cert.pem` +\` + +const clientKey = \` +`cat $root/client/key.pem` +\` +END diff --git a/vendor/github.com/rabbitmq/amqp091-go/channel.go b/vendor/github.com/rabbitmq/amqp091-go/channel.go new file mode 100644 index 0000000000..ae6f2d1ad1 --- /dev/null +++ b/vendor/github.com/rabbitmq/amqp091-go/channel.go @@ -0,0 +1,1709 @@ +// Copyright (c) 2021 VMware, Inc. or its affiliates. All Rights Reserved. +// Copyright (c) 2012-2021, Sean Treadway, SoundCloud Ltd. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package amqp091 + +import ( + "context" + "errors" + "reflect" + "sync" + "sync/atomic" +) + +// 0 1 3 7 size+7 size+8 +// +------+---------+-------------+ +------------+ +-----------+ +// | type | channel | size | | payload | | frame-end | +// +------+---------+-------------+ +------------+ +-----------+ +// +// octet short long size octets octet +const frameHeaderSize = 1 + 2 + 4 + 1 + +/* +Channel represents an AMQP channel. Used as a context for valid message +exchange. 
Errors on methods with this Channel as a receiver means this channel +should be discarded and a new channel established. +*/ +type Channel struct { + destructor sync.Once + m sync.Mutex // struct field mutex + confirmM sync.Mutex // publisher confirms state mutex + notifyM sync.RWMutex + + connection *Connection + + rpc chan message + consumers *consumers + + id uint16 + + // closed is set to 1 when the channel has been closed - see Channel.send() + closed int32 + + // true when we will never notify again + noNotify bool + + // Channel and Connection exceptions will be broadcast on these listeners. + closes []chan *Error + + // Listeners for active=true flow control. When true is sent to a listener, + // publishing should pause until false is sent to listeners. + flows []chan bool + + // Listeners for returned publishings for unroutable messages on mandatory + // publishings or undeliverable messages on immediate publishings. + returns []chan Return + + // Listeners for when the server notifies the client that + // a consumer has been cancelled. + cancels []chan string + + // Allocated when in confirm mode in order to track publish counter and order confirms + confirms *confirms + confirming bool + + // Selects on any errors from shutdown during RPC + errors chan *Error + + // State machine that manages frame order, must only be mutated by the connection + recv func(*Channel, frame) + + // Current state for frame re-assembly, only mutated from recv + message messageWithContent + header *headerFrame + body []byte +} + +// Constructs a new channel with the given framing rules +func newChannel(c *Connection, id uint16) *Channel { + return &Channel{ + connection: c, + id: id, + rpc: make(chan message), + consumers: makeConsumers(), + confirms: newConfirms(), + recv: (*Channel).recvMethod, + errors: make(chan *Error, 1), + } +} + +// Signal that from now on, Channel.send() should call Channel.sendClosed() +func (ch *Channel) setClosed() { + atomic.StoreInt32(&ch.closed, 1) +} + +// shutdown is called by Connection after the channel has been removed from the +// connection registry. +func (ch *Channel) shutdown(e *Error) { + ch.setClosed() + + ch.destructor.Do(func() { + ch.m.Lock() + defer ch.m.Unlock() + + // Grab an exclusive lock for the notify channels + ch.notifyM.Lock() + defer ch.notifyM.Unlock() + + // Broadcast abnormal shutdown + if e != nil { + for _, c := range ch.closes { + c <- e + } + // Notify RPC if we're selecting + ch.errors <- e + } + + ch.consumers.close() + + for _, c := range ch.closes { + close(c) + } + + for _, c := range ch.flows { + close(c) + } + + for _, c := range ch.returns { + close(c) + } + + for _, c := range ch.cancels { + close(c) + } + + // Set the slices to nil to prevent the dispatch() range from sending on + // the now closed channels after we release the notifyM mutex + ch.flows = nil + ch.closes = nil + ch.returns = nil + ch.cancels = nil + + if ch.confirms != nil { + ch.confirms.Close() + } + + close(ch.errors) + ch.noNotify = true + }) +} + +// send calls Channel.sendOpen() during normal operation. +// +// After the channel has been closed, send calls Channel.sendClosed(), ensuring +// only 'channel.close' is sent to the server. 
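+//
+// (editorial note) The closed flag consulted here is the int32 written by
+// setClosed() via atomic.StoreInt32 and read by IsClosed() via
+// atomic.LoadInt32, so this dispatch needs no extra locking.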
+func (ch *Channel) send(msg message) (err error) { + // If the channel is closed, use Channel.sendClosed() + if ch.IsClosed() { + return ch.sendClosed(msg) + } + + return ch.sendOpen(msg) +} + +func (ch *Channel) open() error { + return ch.call(&channelOpen{}, &channelOpenOk{}) +} + +// Performs a request/response call for when the message is not NoWait and is +// specified as Synchronous. +func (ch *Channel) call(req message, res ...message) error { + if err := ch.send(req); err != nil { + return err + } + + if req.wait() { + select { + case e, ok := <-ch.errors: + if ok { + return e + } + return ErrClosed + + case msg := <-ch.rpc: + if msg != nil { + for _, try := range res { + if reflect.TypeOf(msg) == reflect.TypeOf(try) { + // *res = *msg + vres := reflect.ValueOf(try).Elem() + vmsg := reflect.ValueOf(msg).Elem() + vres.Set(vmsg) + return nil + } + } + return ErrCommandInvalid + } + // RPC channel has been closed without an error, likely due to a hard + // error on the Connection. This indicates we have already been + // shutdown and if were waiting, will have returned from the errors chan. + return ErrClosed + } + } + + return nil +} + +func (ch *Channel) sendClosed(msg message) (err error) { + // After a 'channel.close' is sent or received the only valid response is + // channel.close-ok + if _, ok := msg.(*channelCloseOk); ok { + return ch.connection.send(&methodFrame{ + ChannelId: ch.id, + Method: msg, + }) + } + + return ErrClosed +} + +func (ch *Channel) sendOpen(msg message) (err error) { + if content, ok := msg.(messageWithContent); ok { + props, body := content.getContent() + class, _ := content.id() + + // catch client max frame size==0 and server max frame size==0 + // set size to length of what we're trying to publish + var size int + if ch.connection.Config.FrameSize > 0 { + size = ch.connection.Config.FrameSize - frameHeaderSize + } else { + size = len(body) + } + + // If the channel is closed, use Channel.sendClosed() + if ch.IsClosed() { + return ch.sendClosed(msg) + } + + // Flush the buffer only after all the Frames that comprise the Message + // have been written to maximise benefits of using a buffered writer. + defer func() { + if endError := ch.connection.endSendUnflushed(); endError != nil { + if err == nil { + err = endError + } + } + }() + + // We use sendUnflushed() in this method as sending the message requires + // sending multiple Frames (methodFrame, headerFrame, N x bodyFrame). + // Flushing after each Frame is inefficient, as it negates much of the + // benefit of using a buffered writer and results in more syscalls than + // necessary. Flushing buffers after every frame can have a significant + // performance impact when sending (e.g. basicPublish) small messages, + // so sendUnflushed() performs an *Unflushed* write, but is otherwise + // equivalent to the send() method. We later use the separate flush + // method to explicitly flush the buffer after all Frames are written. 
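+		// (editorial note) The unflushed writes below follow the wire order
+		// AMQP 0-9-1 requires for content: one method frame, one content
+		// header carrying the total body size, then body frames chunked to
+		// the negotiated frame size.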
+ if err = ch.connection.sendUnflushed(&methodFrame{ + ChannelId: ch.id, + Method: content, + }); err != nil { + return + } + + if err = ch.connection.sendUnflushed(&headerFrame{ + ChannelId: ch.id, + ClassId: class, + Size: uint64(len(body)), + Properties: props, + }); err != nil { + return + } + + // chunk body into size (max frame size - frame header size) + for i, j := 0, size; i < len(body); i, j = j, j+size { + if j > len(body) { + j = len(body) + } + + if err = ch.connection.sendUnflushed(&bodyFrame{ + ChannelId: ch.id, + Body: body[i:j], + }); err != nil { + return + } + } + } else { + // If the channel is closed, use Channel.sendClosed() + if ch.IsClosed() { + return ch.sendClosed(msg) + } + + err = ch.connection.send(&methodFrame{ + ChannelId: ch.id, + Method: msg, + }) + } + + return +} + +// Eventually called via the state machine from the connection's reader +// goroutine, so assumes serialized access. +func (ch *Channel) dispatch(msg message) { + switch m := msg.(type) { + case *channelClose: + // Note: channel state is set to closed immedately after the message is + // decoded by the Connection + + // lock before sending connection.close-ok + // to avoid unexpected interleaving with basic.publish frames if + // publishing is happening concurrently + ch.m.Lock() + if err := ch.send(&channelCloseOk{}); err != nil { + Logger.Printf("error sending channelCloseOk, channel id: %d error: %+v", ch.id, err) + } + ch.m.Unlock() + ch.connection.closeChannel(ch, newError(m.ReplyCode, m.ReplyText)) + + case *channelFlow: + ch.notifyM.RLock() + for _, c := range ch.flows { + c <- m.Active + } + ch.notifyM.RUnlock() + if err := ch.send(&channelFlowOk{Active: m.Active}); err != nil { + Logger.Printf("error sending channelFlowOk, channel id: %d error: %+v", ch.id, err) + } + + case *basicCancel: + ch.notifyM.RLock() + for _, c := range ch.cancels { + c <- m.ConsumerTag + } + ch.notifyM.RUnlock() + ch.consumers.cancel(m.ConsumerTag) + + case *basicReturn: + ret := newReturn(*m) + ch.notifyM.RLock() + for _, c := range ch.returns { + c <- *ret + } + ch.notifyM.RUnlock() + + case *basicAck: + if ch.confirming { + if m.Multiple { + ch.confirms.Multiple(Confirmation{m.DeliveryTag, true}) + } else { + ch.confirms.One(Confirmation{m.DeliveryTag, true}) + } + } + + case *basicNack: + if ch.confirming { + if m.Multiple { + ch.confirms.Multiple(Confirmation{m.DeliveryTag, false}) + } else { + ch.confirms.One(Confirmation{m.DeliveryTag, false}) + } + } + + case *basicDeliver: + ch.consumers.send(m.ConsumerTag, newDelivery(ch, m)) + // TODO log failed consumer and close channel, this can happen when + // deliveries are in flight and a no-wait cancel has happened + + default: + ch.rpc <- msg + } +} + +func (ch *Channel) transition(f func(*Channel, frame)) { + ch.recv = f +} + +func (ch *Channel) recvMethod(f frame) { + switch frame := f.(type) { + case *methodFrame: + if msg, ok := frame.Method.(messageWithContent); ok { + ch.body = make([]byte, 0) + ch.message = msg + ch.transition((*Channel).recvHeader) + return + } + + ch.dispatch(frame.Method) // termination state + ch.transition((*Channel).recvMethod) + + case *headerFrame: + // drop + ch.transition((*Channel).recvMethod) + + case *bodyFrame: + // drop + ch.transition((*Channel).recvMethod) + + default: + panic("unexpected frame type") + } +} + +func (ch *Channel) recvHeader(f frame) { + switch frame := f.(type) { + case *methodFrame: + // interrupt content and handle method + ch.recvMethod(f) + + case *headerFrame: + // start collecting if we 
+		ch.header = frame
+
+		if frame.Size == 0 {
+			ch.message.setContent(ch.header.Properties, ch.body)
+			ch.dispatch(ch.message) // termination state
+			ch.transition((*Channel).recvMethod)
+			return
+		}
+		ch.transition((*Channel).recvContent)
+
+	case *bodyFrame:
+		// drop and reset
+		ch.transition((*Channel).recvMethod)
+
+	default:
+		panic("unexpected frame type")
+	}
+}
+
+// state after method + header and before the length
+// defined by the header has been reached
+func (ch *Channel) recvContent(f frame) {
+	switch frame := f.(type) {
+	case *methodFrame:
+		// interrupt content and handle method
+		ch.recvMethod(f)
+
+	case *headerFrame:
+		// drop and reset
+		ch.transition((*Channel).recvMethod)
+
+	case *bodyFrame:
+		if cap(ch.body) == 0 {
+			ch.body = make([]byte, 0, ch.header.Size)
+		}
+		ch.body = append(ch.body, frame.Body...)
+
+		if uint64(len(ch.body)) >= ch.header.Size {
+			ch.message.setContent(ch.header.Properties, ch.body)
+			ch.dispatch(ch.message) // termination state
+			ch.transition((*Channel).recvMethod)
+			return
+		}
+
+		ch.transition((*Channel).recvContent)
+
+	default:
+		panic("unexpected frame type")
+	}
+}
+
+/*
+Close initiates a clean channel closure by sending a close message with the error
+code set to '200'.
+
+It is safe to call this method multiple times.
+*/
+func (ch *Channel) Close() error {
+	defer ch.connection.closeChannel(ch, nil)
+	return ch.call(
+		&channelClose{ReplyCode: replySuccess},
+		&channelCloseOk{},
+	)
+}
+
+// IsClosed returns true if the channel is marked as closed, otherwise false
+// is returned.
+func (ch *Channel) IsClosed() bool {
+	return atomic.LoadInt32(&ch.closed) == 1
+}
+
+/*
+NotifyClose registers a listener for when the server sends a channel or
+connection exception in the form of a Connection.Close or Channel.Close method.
+Connection exceptions will be broadcast to all open channels and all channels
+will be closed, whereas channel exceptions will only be broadcast to listeners
+of this channel.
+
+The chan provided will be closed when the Channel is closed and on a
+graceful close, no error will be sent.
+
+In case of a non-graceful close the error will be sent synchronously by the
+library, so the caller must consume from the channel in order to avoid deadlocks.
+*/
+func (ch *Channel) NotifyClose(c chan *Error) chan *Error {
+	ch.notifyM.Lock()
+	defer ch.notifyM.Unlock()
+
+	if ch.noNotify {
+		close(c)
+	} else {
+		ch.closes = append(ch.closes, c)
+	}
+
+	return c
+}
+
+/*
+NotifyFlow registers a listener for basic.flow methods sent by the server.
+When `false` is sent on one of the listener channels, all publishers should
+pause until a `true` is sent.
+
+The server may ask the producer to pause or restart the flow of Publishings
+sent on a channel. This is a simple flow-control mechanism that a server can
+use to avoid overflowing its queues or otherwise finding itself receiving more
+messages than it can process. Note that this method is not intended for window
+control. It does not affect contents returned by basic.get-ok methods.
+
+When a new channel is opened, it is active (flow is active). Some
+applications assume that channels are inactive until started. To emulate
+this behavior a client MAY open the channel, then pause it.
+
+Publishers should respond to flow messages as rapidly as possible, and the
+server may disconnect over-producing channels that do not respect these
+messages.
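+
+A minimal listener sketch (setPaused is a hypothetical application-side hook,
+not part of this library):
+
+	flows := ch.NotifyFlow(make(chan bool))
+	go func() {
+		for active := range flows {
+			setPaused(!active) // pause publishing while active is false
+		}
+	}()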
+
+basic.flow-ok methods will always be returned to the server regardless of
+the number of listeners.
+
+To control the flow of deliveries from the server, use the Channel.Flow()
+method instead.
+
+Note: RabbitMQ prefers to use TCP pushback on the network connection instead
+of sending basic.flow. This means that if a single channel is producing too
+much on the same connection, all channels using that connection will suffer,
+including acknowledgments from deliveries. Use different Connections if you
+desire to interleave consumers and producers in the same process, to keep your
+basic.ack messages from being rate limited along with your basic.publish
+messages.
+*/
+func (ch *Channel) NotifyFlow(c chan bool) chan bool {
+	ch.notifyM.Lock()
+	defer ch.notifyM.Unlock()
+
+	if ch.noNotify {
+		close(c)
+	} else {
+		ch.flows = append(ch.flows, c)
+	}
+
+	return c
+}
+
+/*
+NotifyReturn registers a listener for basic.return methods. These can be sent
+from the server when a publish is undeliverable either from the mandatory or
+immediate flags.
+
+A return struct has a copy of the Publishing along with some error
+information about why the publishing failed.
+*/
+func (ch *Channel) NotifyReturn(c chan Return) chan Return {
+	ch.notifyM.Lock()
+	defer ch.notifyM.Unlock()
+
+	if ch.noNotify {
+		close(c)
+	} else {
+		ch.returns = append(ch.returns, c)
+	}
+
+	return c
+}
+
+/*
+NotifyCancel registers a listener for basic.cancel methods. These can be sent
+from the server when a queue is deleted or when consuming from a mirrored queue
+where the master has just failed (and was moved to another node).
+
+The subscription tag is returned to the listener.
+*/
+func (ch *Channel) NotifyCancel(c chan string) chan string {
+	ch.notifyM.Lock()
+	defer ch.notifyM.Unlock()
+
+	if ch.noNotify {
+		close(c)
+	} else {
+		ch.cancels = append(ch.cancels, c)
+	}
+
+	return c
+}
+
+/*
+NotifyConfirm calls NotifyPublish and starts a goroutine sending
+ordered Ack and Nack DeliveryTags to the respective channels.
+
+For strict ordering, use NotifyPublish instead.
+*/
+func (ch *Channel) NotifyConfirm(ack, nack chan uint64) (chan uint64, chan uint64) {
+	confirms := ch.NotifyPublish(make(chan Confirmation, cap(ack)+cap(nack)))
+
+	go func() {
+		for c := range confirms {
+			if c.Ack {
+				ack <- c.DeliveryTag
+			} else {
+				nack <- c.DeliveryTag
+			}
+		}
+		close(ack)
+		if nack != ack {
+			close(nack)
+		}
+	}()
+
+	return ack, nack
+}
+
+/*
+NotifyPublish registers a listener for reliable publishing. Receives from this
+chan for every publish after Channel.Confirm will be in order starting with
+DeliveryTag 1.
+
+There will be one and only one Confirmation per Publishing, starting with a
+delivery tag of 1 and progressing sequentially until the total number of
+Publishings has been seen by the server.
+
+Acknowledgments will be received in the order of delivery from the
+NotifyPublish channels even if the server acknowledges them out of order.
+
+The listener chan will be closed when the Channel is closed.
+
+The capacity of the chan Confirmation must be at least as large as the
+number of outstanding publishings. Not having enough buffer capacity will
+create a deadlock if you attempt to perform other operations on the Connection
+or Channel while confirms are in-flight.
+
+It's advisable to wait for all Confirmations to arrive before calling
+Channel.Close() or Connection.Close().
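+
+A confirm-mode sketch (the buffer capacity of 64 is an illustrative
+assumption; size it to your number of outstanding publishings):
+
+	if err := ch.Confirm(false); err != nil {
+		// the server does not support confirm mode
+	}
+	confirms := ch.NotifyPublish(make(chan Confirmation, 64))
+	// publish, then receive from confirms until every publishing is confirmed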
+
+It is also advisable for the caller to consume from the returned channel until
+it is closed, to avoid possible deadlocks.
+*/
+func (ch *Channel) NotifyPublish(confirm chan Confirmation) chan Confirmation {
+	ch.notifyM.Lock()
+	defer ch.notifyM.Unlock()
+
+	if ch.noNotify {
+		close(confirm)
+	} else {
+		ch.confirms.Listen(confirm)
+	}
+
+	return confirm
+}
+
+/*
+Qos controls how many messages or how many bytes the server will try to keep on
+the network for consumers before receiving delivery acks. The intent of Qos is
+to make sure the network buffers stay full between the server and client.
+
+With a prefetch count greater than zero, the server will deliver that many
+messages to consumers before acknowledgments are received. The server ignores
+this option when consumers are started with noAck because no acknowledgments
+are expected or sent.
+
+With a prefetch size greater than zero, the server will try to keep at least
+that many bytes of deliveries flushed to the network before receiving
+acknowledgments from the consumers. This option is ignored when consumers are
+started with noAck.
+
+When global is true, these Qos settings apply to all existing and future
+consumers on all channels on the same connection. When false, the Channel.Qos
+settings will apply to all existing and future consumers on this channel.
+
+Please see the RabbitMQ Consumer Prefetch documentation for an explanation of
+how the global flag is implemented in RabbitMQ, as it differs from the
+AMQP 0.9.1 specification in that global Qos settings are limited in scope to
+channels, not connections (https://www.rabbitmq.com/consumer-prefetch.html).
+
+To get round-robin behavior between consumers consuming from the same queue on
+different connections, set the prefetch count to 1, and the next available
+message on the server will be delivered to the next available consumer.
+
+If your consumer work time is reasonably consistent and not much greater
+than two times your network round trip time, you will see significant
+throughput improvements starting with a prefetch count of 2 or slightly
+greater, as described by benchmarks on RabbitMQ.
+
+http://www.rabbitmq.com/blog/2012/04/25/rabbitmq-performance-measurements-part-2/
+*/
+func (ch *Channel) Qos(prefetchCount, prefetchSize int, global bool) error {
+	return ch.call(
+		&basicQos{
+			PrefetchCount: uint16(prefetchCount),
+			PrefetchSize:  uint32(prefetchSize),
+			Global:        global,
+		},
+		&basicQosOk{},
+	)
+}
+
+/*
+Cancel stops deliveries to the consumer chan established in Channel.Consume and
+identified by consumer.
+
+Only use this method to cleanly stop receiving deliveries from the server and
+cleanly shut down the consumer chan identified by this tag. Using this method
+and waiting for remaining messages to flush from the consumer chan will ensure
+all messages received on the network will be delivered to the receiver of your
+consumer chan.
+
+Continue consuming from the chan Delivery provided by Channel.Consume until the
+chan closes.
+
+When noWait is true, do not wait for the server to acknowledge the cancel.
+Only use this when you are certain there are no deliveries in flight that
+require an acknowledgment, otherwise they will arrive and be dropped in the
+client without an ack, and will not be redelivered to other consumers.
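+
+A typical shutdown sketch (the consumer tag "worker-1" and the deliveries
+variable from an earlier Channel.Consume call are illustrative):
+
+	if err := ch.Cancel("worker-1", false); err != nil {
+		// handle error
+	}
+	for d := range deliveries {
+		d.Ack(false) // drain deliveries still in flight until the chan closes
+	}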
+*/
+func (ch *Channel) Cancel(consumer string, noWait bool) error {
+	req := &basicCancel{
+		ConsumerTag: consumer,
+		NoWait:      noWait,
+	}
+	res := &basicCancelOk{}
+
+	if err := ch.call(req, res); err != nil {
+		return err
+	}
+
+	if req.wait() {
+		ch.consumers.cancel(res.ConsumerTag)
+	} else {
+		// Potentially could drop deliveries in flight
+		ch.consumers.cancel(consumer)
+	}
+
+	return nil
+}
+
+/*
+QueueDeclare declares a queue to hold messages and deliver to consumers.
+Declaring creates a queue if it doesn't already exist, or ensures that an
+existing queue matches the same parameters.
+
+Every queue declared gets a default binding to the empty exchange "" which has
+the type "direct" with the routing key matching the queue's name. With this
+default binding, it is possible to publish messages that route directly to
+this queue by publishing to "" with the routing key of the queue name.
+
+	QueueDeclare("alerts", true, false, false, false, nil)
+	Publish("", "alerts", false, false, Publishing{Body: []byte("...")})
+
+	Delivery       Exchange  Key       Queue
+	-----------------------------------------------
+	key: alerts -> ""     -> alerts -> alerts
+
+The queue name may be empty, in which case the server will generate a unique
+name which will be returned in the Name field of the Queue struct.
+
+Durable and Non-Auto-Deleted queues will survive server restarts and remain
+when there are no remaining consumers or bindings. Persistent publishings will
+be restored in this queue on server restart. These queues are only able to be
+bound to durable exchanges.
+
+Non-Durable and Auto-Deleted queues will not be redeclared on server restart
+and will be deleted by the server after a short time when the last consumer is
+canceled or the last consumer's channel is closed. Queues with this lifetime
+can also be deleted normally with QueueDelete. These non-durable queues can
+only be bound to non-durable exchanges.
+
+Non-Durable and Non-Auto-Deleted queues will remain declared as long as the
+server is running, regardless of how many consumers there are. This lifetime is
+useful for temporary topologies that may have long delays between consumer
+activity. These queues can only be bound to non-durable exchanges.
+
+Durable and Auto-Deleted queues will be restored on server restart, but will be
+removed when there are no active consumers. This lifetime is unlikely to be
+useful.
+
+Exclusive queues are only accessible by the connection that declares them and
+will be deleted when the connection closes. Channels on other connections
+will receive an error when attempting to declare, bind, consume, purge or
+delete a queue with the same name.
+
+When noWait is true, the queue is assumed to be declared on the server. A
+channel exception will arrive if the conditions are met for existing queues
+or attempting to modify an existing queue from a different connection.
+
+When the error return value is not nil, you can assume the queue could not be
+declared with these parameters, and the channel will be closed.
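+
+For example, a server-named, exclusive, auto-deleted queue (a sketch; the flag
+combination is illustrative):
+
+	q, err := ch.QueueDeclare("", false, true, true, false, nil)
+	if err != nil {
+		// the channel is closed and the queue was not declared
+	}
+	// q.Name holds the server-generated queue name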
+*/
+func (ch *Channel) QueueDeclare(name string, durable, autoDelete, exclusive, noWait bool, args Table) (Queue, error) {
+	if err := args.Validate(); err != nil {
+		return Queue{}, err
+	}
+
+	req := &queueDeclare{
+		Queue:      name,
+		Passive:    false,
+		Durable:    durable,
+		AutoDelete: autoDelete,
+		Exclusive:  exclusive,
+		NoWait:     noWait,
+		Arguments:  args,
+	}
+	res := &queueDeclareOk{}
+
+	if err := ch.call(req, res); err != nil {
+		return Queue{}, err
+	}
+
+	if req.wait() {
+		return Queue{
+			Name:      res.Queue,
+			Messages:  int(res.MessageCount),
+			Consumers: int(res.ConsumerCount),
+		}, nil
+	}
+
+	return Queue{Name: name}, nil
+}
+
+/*
+QueueDeclarePassive is functionally and parametrically equivalent to
+QueueDeclare, except that it sets the "passive" attribute to true. A passive
+queue is assumed by RabbitMQ to already exist, and attempting to connect to a
+non-existent queue will cause RabbitMQ to throw an exception. This function
+can be used to test for the existence of a queue.
+*/
+func (ch *Channel) QueueDeclarePassive(name string, durable, autoDelete, exclusive, noWait bool, args Table) (Queue, error) {
+	if err := args.Validate(); err != nil {
+		return Queue{}, err
+	}
+
+	req := &queueDeclare{
+		Queue:      name,
+		Passive:    true,
+		Durable:    durable,
+		AutoDelete: autoDelete,
+		Exclusive:  exclusive,
+		NoWait:     noWait,
+		Arguments:  args,
+	}
+	res := &queueDeclareOk{}
+
+	if err := ch.call(req, res); err != nil {
+		return Queue{}, err
+	}
+
+	if req.wait() {
+		return Queue{
+			Name:      res.Queue,
+			Messages:  int(res.MessageCount),
+			Consumers: int(res.ConsumerCount),
+		}, nil
+	}
+
+	return Queue{Name: name}, nil
+}
+
+/*
+QueueInspect passively declares a queue by name to inspect the current message
+count and consumer count.
+
+Use this method to check how many messages ready for delivery reside in the queue,
+how many consumers are receiving deliveries, and whether a queue by this
+name already exists.
+
+If the queue by this name exists, use Channel.QueueDeclare to check whether it
+is declared with specific parameters.
+
+If a queue by this name does not exist, an error will be returned and the
+channel will be closed.
+
+Deprecated: Use QueueDeclare with "Passive: true" instead.
+*/
+func (ch *Channel) QueueInspect(name string) (Queue, error) {
+	req := &queueDeclare{
+		Queue:   name,
+		Passive: true,
+	}
+	res := &queueDeclareOk{}
+
+	err := ch.call(req, res)
+
+	state := Queue{
+		Name:      name,
+		Messages:  int(res.MessageCount),
+		Consumers: int(res.ConsumerCount),
+	}
+
+	return state, err
+}
+
+/*
+QueueBind binds an exchange to a queue so that publishings to the exchange will
+be routed to the queue when the publishing routing key matches the binding
+routing key.
+
+	QueueBind("pagers", "alert", "log", false, nil)
+	QueueBind("emails", "info", "log", false, nil)
+
+	Delivery       Exchange  Key       Queue
+	-----------------------------------------------
+	key: alert --> log ----> alert --> pagers
+	key: info ---> log ----> info ---> emails
+	key: debug --> log       (none)    (dropped)
+
+If a binding with the same key and arguments already exists between the
+exchange and queue, the attempt to rebind will be ignored and the existing
+binding will be retained.
+
+In the case that multiple bindings may cause the message to be routed to the
+same queue, the server will only route the publishing once. This is possible
+with topic exchanges.
+
+	QueueBind("pagers", "alert", "amq.topic", false, nil)
+	QueueBind("emails", "info", "amq.topic", false, nil)
+	QueueBind("emails", "#", "amq.topic", false, nil) // match everything
+
+	Delivery       Exchange        Key       Queue
+	-----------------------------------------------
+	key: alert --> amq.topic ----> alert --> pagers
+	key: info ---> amq.topic ----> # ------> emails
+	                         \---> info ---/
+	key: debug --> amq.topic ----> # ------> emails
+
+It is only possible to bind a durable queue to a durable exchange regardless of
+whether the queue or exchange is auto-deleted. Bindings between durable queues
+and exchanges will also be restored on server restart.
+
+If the binding could not complete, an error will be returned and the channel
+will be closed.
+
+When noWait is false and the queue could not be bound, the channel will be
+closed with an error.
+*/
+func (ch *Channel) QueueBind(name, key, exchange string, noWait bool, args Table) error {
+	if err := args.Validate(); err != nil {
+		return err
+	}
+
+	return ch.call(
+		&queueBind{
+			Queue:      name,
+			Exchange:   exchange,
+			RoutingKey: key,
+			NoWait:     noWait,
+			Arguments:  args,
+		},
+		&queueBindOk{},
+	)
+}
+
+/*
+QueueUnbind removes a binding between an exchange and queue matching the key and
+arguments.
+
+It is possible to send an empty string for the exchange name, which means to
+unbind the queue from the default exchange.
+*/
+func (ch *Channel) QueueUnbind(name, key, exchange string, args Table) error {
+	if err := args.Validate(); err != nil {
+		return err
+	}
+
+	return ch.call(
+		&queueUnbind{
+			Queue:      name,
+			Exchange:   exchange,
+			RoutingKey: key,
+			Arguments:  args,
+		},
+		&queueUnbindOk{},
+	)
+}
+
+/*
+QueuePurge removes all messages from the named queue which are not waiting to
+be acknowledged. Messages that have been delivered but have not yet been
+acknowledged will not be removed.
+
+When successful, returns the number of messages purged.
+
+If noWait is true, do not wait for the server response and the number of
+messages purged will not be meaningful.
+*/
+func (ch *Channel) QueuePurge(name string, noWait bool) (int, error) {
+	req := &queuePurge{
+		Queue:  name,
+		NoWait: noWait,
+	}
+	res := &queuePurgeOk{}
+
+	err := ch.call(req, res)
+
+	return int(res.MessageCount), err
+}
+
+/*
+QueueDelete removes the queue from the server including all bindings, then
+purges the messages based on server configuration, returning the number of
+messages purged.
+
+When ifUnused is true, the queue will not be deleted if there are any
+consumers on the queue. If there are consumers, an error will be returned and
+the channel will be closed.
+
+When ifEmpty is true, the queue will not be deleted if there are any messages
+remaining on the queue. If there are messages, an error will be returned and
+the channel will be closed.
+
+When noWait is true, the queue will be deleted without waiting for a response
+from the server. The purged message count will not be meaningful. If the queue
+could not be deleted, a channel exception will be raised and the channel will
+be closed.
+*/
+func (ch *Channel) QueueDelete(name string, ifUnused, ifEmpty, noWait bool) (int, error) {
+	req := &queueDelete{
+		Queue:    name,
+		IfUnused: ifUnused,
+		IfEmpty:  ifEmpty,
+		NoWait:   noWait,
+	}
+	res := &queueDeleteOk{}
+
+	err := ch.call(req, res)
+
+	return int(res.MessageCount), err
+}
+
+/*
+Consume immediately starts delivering queued messages.
+
+Begin receiving on the returned chan Delivery before any other operation on the
+Connection or Channel.
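+
+A minimal receive-loop sketch (the queue name and flag values are
+illustrative):
+
+	deliveries, err := ch.Consume("work", "", false, false, false, false, nil)
+	if err != nil {
+		// handle error
+	}
+	go func() {
+		for d := range deliveries {
+			// process the delivery, then acknowledge it
+			d.Ack(false)
+		}
+	}()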
+
+Continues deliveries to the returned chan Delivery until Channel.Cancel,
+Connection.Close, Channel.Close, or an AMQP exception occurs. Consumers must
+range over the chan to ensure all deliveries are received. Unreceived
+deliveries will block all methods on the same connection.
+
+All deliveries in AMQP must be acknowledged. The consumer is expected to
+call Delivery.Ack after it has successfully processed the delivery. If the
+consumer is cancelled or the channel or connection is closed, any unacknowledged
+deliveries will be requeued at the end of the same queue.
+
+The consumer is identified by a string that is unique and scoped for all
+consumers on this channel. If you wish to eventually cancel the consumer, use
+the same non-empty identifier in Channel.Cancel. An empty string will cause
+the library to generate a unique identity. The consumer identity will be
+included in every Delivery in the ConsumerTag field.
+
+When autoAck (also known as noAck) is true, the server will acknowledge
+deliveries to this consumer prior to writing the delivery to the network. When
+autoAck is true, the consumer should not call Delivery.Ack. Automatically
+acknowledging deliveries means that some deliveries may get lost if the
+consumer is unable to process them after the server delivers them.
+See http://www.rabbitmq.com/confirms.html for more details.
+
+When exclusive is true, the server will ensure that this is the sole consumer
+from this queue. When exclusive is false, the server will fairly distribute
+deliveries across multiple consumers.
+
+The noLocal flag is not supported by RabbitMQ.
+
+It's advisable to use separate connections for
+Channel.Publish and Channel.Consume so as not to have TCP pushback on publishing
+affect the ability to consume messages, so this parameter is here mostly for
+completeness.
+
+When noWait is true, do not wait for the server to confirm the request and
+immediately begin deliveries. If it is not possible to consume, a channel
+exception will be raised and the channel will be closed.
+
+Optional arguments can be provided that have specific semantics for the queue
+or server.
+
+Inflight messages, limited by Channel.Qos, will be buffered until received from
+the returned chan.
+
+When the Channel or Connection is closed, all buffered and inflight messages will
+be dropped.
+
+When the consumer tag is cancelled, all inflight messages will be delivered until
+the returned chan is closed.
+*/
+func (ch *Channel) Consume(queue, consumer string, autoAck, exclusive, noLocal, noWait bool, args Table) (<-chan Delivery, error) {
+	// When we return from ch.call, there may be a delivery already for the
+	// consumer that hasn't been added to the consumer hash yet. Because of
+	// this, we never rely on the server picking a consumer tag for us.
+
+	if err := args.Validate(); err != nil {
+		return nil, err
+	}
+
+	if consumer == "" {
+		consumer = uniqueConsumerTag()
+	}
+
+	req := &basicConsume{
+		Queue:       queue,
+		ConsumerTag: consumer,
+		NoLocal:     noLocal,
+		NoAck:       autoAck,
+		Exclusive:   exclusive,
+		NoWait:      noWait,
+		Arguments:   args,
+	}
+	res := &basicConsumeOk{}
+
+	deliveries := make(chan Delivery)
+
+	ch.consumers.add(consumer, deliveries)
+
+	if err := ch.call(req, res); err != nil {
+		ch.consumers.cancel(consumer)
+		return nil, err
+	}
+
+	return deliveries, nil
+}
+
+/*
+ExchangeDeclare declares an exchange on the server. If the exchange does not
+already exist, the server will create it.
+If the exchange exists, the server verifies that it matches the provided type,
+durability and auto-delete flags.
+
+Errors returned from this method will close the channel.
+
+Exchange names starting with "amq." are reserved for pre-declared and
+standardized exchanges. The client MAY declare an exchange starting with
+"amq." if the passive option is set, or the exchange already exists. Names can
+consist of a non-empty sequence of letters, digits, hyphen, underscore,
+period, or colon.
+
+Each exchange belongs to one of a set of exchange kinds/types implemented by
+the server. The exchange types define the functionality of the exchange - i.e.
+how messages are routed through it. Once an exchange is declared, its type
+cannot be changed. The common types are "direct", "fanout", "topic" and
+"headers".
+
+Durable and Non-Auto-Deleted exchanges will survive server restarts and remain
+declared when there are no remaining bindings. This is the best lifetime for
+long-lived exchange configurations like stable routes and default exchanges.
+
+Non-Durable and Auto-Deleted exchanges will be deleted when there are no
+remaining bindings and not restored on server restart. This lifetime is
+useful for temporary topologies that should not pollute the virtual host on
+failure or after the consumers have completed.
+
+Non-Durable and Non-Auto-deleted exchanges will remain as long as the server is
+running, including when there are no remaining bindings. This is useful for
+temporary topologies that may have long delays between bindings.
+
+Durable and Auto-Deleted exchanges will survive server restarts and will be
+removed, both before and after server restarts, when there are no remaining
+bindings. These exchanges are useful for robust temporary topologies or when
+you require binding durable queues to auto-deleted exchanges.
+
+Note: RabbitMQ declares the default exchange types like 'amq.fanout' as
+durable, so queues that bind to these pre-declared exchanges must also be
+durable.
+
+Exchanges declared as `internal` do not accept publishings. Internal
+exchanges are useful when you wish to implement inter-exchange topologies
+that should not be exposed to users of the broker.
+
+When noWait is true, declare without waiting for a confirmation from the server.
+The channel may be closed as a result of an error. Add a NotifyClose listener
+to respond to any exceptions.
+
+An optional amqp.Table of arguments that are specific to the server's
+implementation of the exchange can be sent for exchange types that require
+extra parameters.
+*/
+func (ch *Channel) ExchangeDeclare(name, kind string, durable, autoDelete, internal, noWait bool, args Table) error {
+	if err := args.Validate(); err != nil {
+		return err
+	}
+
+	return ch.call(
+		&exchangeDeclare{
+			Exchange:   name,
+			Type:       kind,
+			Passive:    false,
+			Durable:    durable,
+			AutoDelete: autoDelete,
+			Internal:   internal,
+			NoWait:     noWait,
+			Arguments:  args,
+		},
+		&exchangeDeclareOk{},
+	)
+}
+
+/*
+ExchangeDeclarePassive is functionally and parametrically equivalent to
+ExchangeDeclare, except that it sets the "passive" attribute to true. A passive
+exchange is assumed by RabbitMQ to already exist, and attempting to connect to a
+non-existent exchange will cause RabbitMQ to throw an exception. This function
+can be used to detect the existence of an exchange.
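+
+For example (the exchange name and flags are illustrative):
+
+	err := ch.ExchangeDeclarePassive("logs", "fanout", true, false, false, false, nil)
+	if err != nil {
+		// the exchange does not exist (or differs); the channel is now closed
+	}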
+*/
+func (ch *Channel) ExchangeDeclarePassive(name, kind string, durable, autoDelete, internal, noWait bool, args Table) error {
+	if err := args.Validate(); err != nil {
+		return err
+	}
+
+	return ch.call(
+		&exchangeDeclare{
+			Exchange:   name,
+			Type:       kind,
+			Passive:    true,
+			Durable:    durable,
+			AutoDelete: autoDelete,
+			Internal:   internal,
+			NoWait:     noWait,
+			Arguments:  args,
+		},
+		&exchangeDeclareOk{},
+	)
+}
+
+/*
+ExchangeDelete removes the named exchange from the server. When an exchange is
+deleted all queue bindings on the exchange are also deleted. If this exchange
+does not exist, the channel will be closed with an error.
+
+When ifUnused is true, the server will only delete the exchange if it has no queue
+bindings. If the exchange has queue bindings, the server does not delete it
+but closes the channel with an exception instead. Set this to true if you are
+not the sole owner of the exchange.
+
+When noWait is true, do not wait for a server confirmation that the exchange has
+been deleted. Failing to delete the exchange could close the channel. Add a
+NotifyClose listener to respond to these channel exceptions.
+*/
+func (ch *Channel) ExchangeDelete(name string, ifUnused, noWait bool) error {
+	return ch.call(
+		&exchangeDelete{
+			Exchange: name,
+			IfUnused: ifUnused,
+			NoWait:   noWait,
+		},
+		&exchangeDeleteOk{},
+	)
+}
+
+/*
+ExchangeBind binds an exchange to another exchange to create inter-exchange
+routing topologies on the server. This can decouple the private topology and
+routing exchanges from exchanges intended solely for publishing endpoints.
+
+Binding two exchanges with identical arguments will not create duplicate
+bindings.
+
+Binding one exchange to another with multiple bindings will only deliver a
+message once. For example, if you bind your exchange to `amq.fanout` with two
+different binding keys, only a single message will be delivered to your
+exchange even though multiple bindings will match.
+
+Given a message delivered to the source exchange, the message will be forwarded
+to the destination exchange when the routing key is matched.
+
+	ExchangeBind("sell", "MSFT", "trade", false, nil)
+	ExchangeBind("buy", "AAPL", "trade", false, nil)
+
+	Delivery       Source      Key      Destination
+	example        exchange             exchange
+	-----------------------------------------------
+	key: AAPL  --> trade ----> MSFT     sell
+	                     \---> AAPL --> buy
+
+When noWait is true, do not wait for the server to confirm the binding. If any
+error occurs the channel will be closed. Add a listener to NotifyClose to
+handle these errors.
+
+Optional arguments specific to the exchanges bound can also be specified.
+*/
+func (ch *Channel) ExchangeBind(destination, key, source string, noWait bool, args Table) error {
+	if err := args.Validate(); err != nil {
+		return err
+	}
+
+	return ch.call(
+		&exchangeBind{
+			Destination: destination,
+			Source:      source,
+			RoutingKey:  key,
+			NoWait:      noWait,
+			Arguments:   args,
+		},
+		&exchangeBindOk{},
+	)
+}
+
+/*
+ExchangeUnbind unbinds the destination exchange from the source exchange on the
+server by removing the routing key between them. This is the inverse of
+ExchangeBind. If the binding does not currently exist, an error will be
+returned.
+
+When noWait is true, do not wait for the server to confirm the deletion of the
+binding. If any error occurs the channel will be closed. Add a listener to
+NotifyClose to handle these errors.
+
+Optional arguments that are specific to the type of exchanges bound can also be
+provided.
+These must match the same arguments specified in ExchangeBind to identify the
+binding.
+*/
+func (ch *Channel) ExchangeUnbind(destination, key, source string, noWait bool, args Table) error {
+	if err := args.Validate(); err != nil {
+		return err
+	}
+
+	return ch.call(
+		&exchangeUnbind{
+			Destination: destination,
+			Source:      source,
+			RoutingKey:  key,
+			NoWait:      noWait,
+			Arguments:   args,
+		},
+		&exchangeUnbindOk{},
+	)
+}
+
+/*
+Publish sends a Publishing from the client to an exchange on the server.
+
+When you want a single message to be delivered to a single queue, you can
+publish to the default exchange with the routingKey of the queue name. This is
+because every declared queue gets an implicit route to the default exchange.
+
+Since publishings are asynchronous, any undeliverable message will get returned
+by the server. Add a listener with Channel.NotifyReturn to handle any
+undeliverable message when calling publish with either the mandatory or
+immediate parameters as true.
+
+Publishings can be undeliverable when the mandatory flag is true and no queue is
+bound that matches the routing key, or when the immediate flag is true and no
+consumer on the matched queue is ready to accept the delivery.
+
+This can return an error when the channel, connection or socket is closed. The
+error or lack of an error does not indicate whether the server has received this
+publishing.
+
+It is possible for publishing to not reach the broker if the underlying socket
+is shut down without pending publishing packets being flushed from the kernel
+buffers. The easy way of making it probable that all publishings reach the
+server is to always call Connection.Close before terminating your publishing
+application. The way to ensure that all publishings reach the server is to add
+a listener to Channel.NotifyPublish and put the channel in confirm mode with
+Channel.Confirm. Publishing delivery tags and their corresponding
+confirmations start at 1. Exit when all publishings are confirmed.
+
+When Publish does not return an error and the channel is in confirm mode, the
+internal counter for DeliveryTags starts at 1 with the first confirmation.
+
+Deprecated: Use PublishWithContext instead.
+*/
+func (ch *Channel) Publish(exchange, key string, mandatory, immediate bool, msg Publishing) error {
+	_, err := ch.PublishWithDeferredConfirmWithContext(context.Background(), exchange, key, mandatory, immediate, msg)
+	return err
+}
+
+/*
+PublishWithContext sends a Publishing from the client to an exchange on the server.
+
+When you want a single message to be delivered to a single queue, you can
+publish to the default exchange with the routingKey of the queue name. This is
+because every declared queue gets an implicit route to the default exchange.
+
+Since publishings are asynchronous, any undeliverable message will get returned
+by the server. Add a listener with Channel.NotifyReturn to handle any
+undeliverable message when calling publish with either the mandatory or
+immediate parameters as true.
+
+Publishings can be undeliverable when the mandatory flag is true and no queue is
+bound that matches the routing key, or when the immediate flag is true and no
+consumer on the matched queue is ready to accept the delivery.
+
+This can return an error when the channel, connection or socket is closed. The
+error or lack of an error does not indicate whether the server has received this
+publishing.
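+
+A publish sketch (the ctx variable, exchange, key and body are illustrative):
+
+	err := ch.PublishWithContext(ctx, "logs", "info", false, false, Publishing{
+		ContentType: "text/plain",
+		Body:        []byte("hello"),
+	})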
+
+It is possible for publishing to not reach the broker if the underlying socket
+is shut down without pending publishing packets being flushed from the kernel
+buffers. The easy way of making it probable that all publishings reach the
+server is to always call Connection.Close before terminating your publishing
+application. The way to ensure that all publishings reach the server is to add
+a listener to Channel.NotifyPublish and put the channel in confirm mode with
+Channel.Confirm. Publishing delivery tags and their corresponding
+confirmations start at 1. Exit when all publishings are confirmed.
+
+When Publish does not return an error and the channel is in confirm mode, the
+internal counter for DeliveryTags starts at 1 with the first confirmation.
+*/
+func (ch *Channel) PublishWithContext(ctx context.Context, exchange, key string, mandatory, immediate bool, msg Publishing) error {
+	_, err := ch.PublishWithDeferredConfirmWithContext(ctx, exchange, key, mandatory, immediate, msg)
+	return err
+}
+
+/*
+PublishWithDeferredConfirm behaves identically to Publish but additionally returns a
+DeferredConfirmation, allowing the caller to wait on the publisher confirmation
+for this message. If the channel has not been put into confirm mode,
+the DeferredConfirmation will be nil.
+
+Deprecated: Use PublishWithDeferredConfirmWithContext instead.
+*/
+func (ch *Channel) PublishWithDeferredConfirm(exchange, key string, mandatory, immediate bool, msg Publishing) (*DeferredConfirmation, error) {
+	return ch.PublishWithDeferredConfirmWithContext(context.Background(), exchange, key, mandatory, immediate, msg)
+}
+
+/*
+PublishWithDeferredConfirmWithContext behaves identically to Publish but additionally returns a
+DeferredConfirmation, allowing the caller to wait on the publisher confirmation
+for this message. If the channel has not been put into confirm mode,
+the DeferredConfirmation will be nil.
+*/
+func (ch *Channel) PublishWithDeferredConfirmWithContext(ctx context.Context, exchange, key string, mandatory, immediate bool, msg Publishing) (*DeferredConfirmation, error) {
+	if ctx == nil {
+		return nil, errors.New("amqp091-go: nil Context")
+	}
+
+	if err := msg.Headers.Validate(); err != nil {
+		return nil, err
+	}
+
+	ch.m.Lock()
+	defer ch.m.Unlock()
+
+	var dc *DeferredConfirmation
+	if ch.confirming {
+		dc = ch.confirms.publish()
+	}
+
+	if err := ch.send(&basicPublish{
+		Exchange:   exchange,
+		RoutingKey: key,
+		Mandatory:  mandatory,
+		Immediate:  immediate,
+		Body:       msg.Body,
+		Properties: properties{
+			Headers:         msg.Headers,
+			ContentType:     msg.ContentType,
+			ContentEncoding: msg.ContentEncoding,
+			DeliveryMode:    msg.DeliveryMode,
+			Priority:        msg.Priority,
+			CorrelationId:   msg.CorrelationId,
+			ReplyTo:         msg.ReplyTo,
+			Expiration:      msg.Expiration,
+			MessageId:       msg.MessageId,
+			Timestamp:       msg.Timestamp,
+			Type:            msg.Type,
+			UserId:          msg.UserId,
+			AppId:           msg.AppId,
+		},
+	}); err != nil {
+		if ch.confirming {
+			ch.confirms.unpublish()
+		}
+		return nil, err
+	}
+
+	return dc, nil
+}
+
+/*
+Get synchronously receives a single Delivery from the head of a queue from the
+server to the client. In almost all cases, using Channel.Consume will be
+preferred.
+
+If there was a delivery waiting on the queue and that delivery was received, the
+second return value will be true. If there was no delivery waiting or an error
+occurred, the ok bool will be false.
+
+All deliveries must be acknowledged, including those from Channel.Get.
+Call Delivery.Ack on the returned delivery when you have fully processed this
+delivery.
+
+When autoAck is true, the server will automatically acknowledge this message so
+you don't have to. But if you are unable to fully process this message before
+the channel or connection is closed, the message will not get requeued.
+*/
+func (ch *Channel) Get(queue string, autoAck bool) (msg Delivery, ok bool, err error) {
+	req := &basicGet{Queue: queue, NoAck: autoAck}
+	res := &basicGetOk{}
+	empty := &basicGetEmpty{}
+
+	if err := ch.call(req, res, empty); err != nil {
+		return Delivery{}, false, err
+	}
+
+	if res.DeliveryTag > 0 {
+		return *(newDelivery(ch, res)), true, nil
+	}
+
+	return Delivery{}, false, nil
+}
+
+/*
+Tx puts the channel into transaction mode on the server. All publishings and
+acknowledgments following this method will be atomically committed or rolled
+back for a single queue. Call either Channel.TxCommit or Channel.TxRollback to
+leave this transaction and immediately start a new one.
+
+The atomicity across multiple queues is not defined as queue declarations and
+bindings are not included in the transaction.
+
+The behavior of publishings that are delivered as mandatory or immediate while
+the channel is in a transaction is not defined.
+
+Once a channel has been put into transaction mode, it cannot be taken out of
+transaction mode. Use a different channel for non-transactional semantics.
+*/
+func (ch *Channel) Tx() error {
+	return ch.call(
+		&txSelect{},
+		&txSelectOk{},
+	)
+}
+
+/*
+TxCommit atomically commits all publishings and acknowledgments for a single
+queue and immediately starts a new transaction.
+
+Calling this method without having called Channel.Tx is an error.
+*/
+func (ch *Channel) TxCommit() error {
+	return ch.call(
+		&txCommit{},
+		&txCommitOk{},
+	)
+}
+
+/*
+TxRollback atomically rolls back all publishings and acknowledgments for a
+single queue and immediately starts a new transaction.
+
+Calling this method without having called Channel.Tx is an error.
+*/
+func (ch *Channel) TxRollback() error {
+	return ch.call(
+		&txRollback{},
+		&txRollbackOk{},
+	)
+}
+
+/*
+Flow pauses the delivery of messages to consumers on this channel. Channels
+are opened with flow control active; to open a channel with paused
+deliveries, immediately call this method with `false` after calling
+Connection.Channel.
+
+When active is `false`, this method asks the server to temporarily pause deliveries
+until called again with active as `true`.
+
+Channel.Get methods will not be affected by flow control.
+
+This method is not intended to act as window control. Use Channel.Qos to limit
+the number of unacknowledged messages or bytes in flight instead.
+
+The server may also send us flow methods to throttle our publishings. A
+well-behaved publishing client should add a listener with Channel.NotifyFlow and
+pause its publishings when `false` is sent on that channel.
+
+Note: RabbitMQ prefers to use TCP push back to control flow for all channels on
+a connection, so under high volume scenarios, it's wise to open separate
+Connections for publishings and deliveries.
+*/
+func (ch *Channel) Flow(active bool) error {
+	return ch.call(
+		&channelFlow{Active: active},
+		&channelFlowOk{},
+	)
+}
+
+/*
+Confirm puts this channel into confirm mode so that the client can ensure all
+publishings have successfully been received by the server.
+After entering this mode, the server will send a basic.ack or basic.nack
+message with the delivery tag set to a 1-based incremental index corresponding
+to every publishing received after this method returns.
+
+Add a listener to Channel.NotifyPublish to respond to the Confirmations. If
+Channel.NotifyPublish is not called, the Confirmations will be silently
+ignored.
+
+The order of acknowledgments is not bound to the order of deliveries.
+
+Ack and Nack confirmations will arrive at some point in the future.
+
+Unroutable mandatory or immediate messages are acknowledged immediately after
+any Channel.NotifyReturn listeners have been notified. Other messages are
+acknowledged when all queues that should have the message routed to them have
+either received acknowledgment of delivery or have enqueued the message,
+persisting the message if necessary.
+
+When noWait is true, the client will not wait for a response. A channel
+exception could occur if the server does not support this method.
+*/
+func (ch *Channel) Confirm(noWait bool) error {
+	if err := ch.call(
+		&confirmSelect{Nowait: noWait},
+		&confirmSelectOk{},
+	); err != nil {
+		return err
+	}
+
+	ch.confirmM.Lock()
+	ch.confirming = true
+	ch.confirmM.Unlock()
+
+	return nil
+}
+
+/*
+Recover redelivers all unacknowledged deliveries on this channel.
+
+When requeue is false, messages will be redelivered to the original consumer.
+
+When requeue is true, messages will be redelivered to any available consumer,
+potentially including the original.
+
+If the deliveries cannot be recovered, an error will be returned and the channel
+will be closed.
+
+Note: this method is not implemented on RabbitMQ; use Delivery.Nack instead.
+
+Deprecated: This method is deprecated in RabbitMQ. RabbitMQ used Recover(true)
+as a mechanism for consumers to tell the broker that they were ready for more
+deliveries, back in 2008-2009. Support for this will be removed from RabbitMQ in
+a future release. Use Nack() with requeue=true instead.
+*/
+func (ch *Channel) Recover(requeue bool) error {
+	return ch.call(
+		&basicRecover{Requeue: requeue},
+		&basicRecoverOk{},
+	)
+}
+
+/*
+Ack acknowledges a delivery by its delivery tag after it has been consumed with
+Channel.Consume or Channel.Get.
+
+Ack acknowledges all messages received prior to the delivery tag when multiple
+is true.
+
+See also Delivery.Ack
+*/
+func (ch *Channel) Ack(tag uint64, multiple bool) error {
+	ch.m.Lock()
+	defer ch.m.Unlock()
+
+	return ch.send(&basicAck{
+		DeliveryTag: tag,
+		Multiple:    multiple,
+	})
+}
+
+/*
+Nack negatively acknowledges a delivery by its delivery tag. Prefer this
+method to notify the server that you were not able to process this delivery and
+it must be redelivered or dropped.
+
+See also Delivery.Nack
+*/
+func (ch *Channel) Nack(tag uint64, multiple bool, requeue bool) error {
+	ch.m.Lock()
+	defer ch.m.Unlock()
+
+	return ch.send(&basicNack{
+		DeliveryTag: tag,
+		Multiple:    multiple,
+		Requeue:     requeue,
+	})
+}
+
+/*
+Reject negatively acknowledges a delivery by its delivery tag. Prefer Nack
+over Reject when communicating with a RabbitMQ server because you can Nack
+multiple messages, reducing the amount of protocol messages to exchange.
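+
+For example (the delivery tag values are illustrative):
+
+	ch.Nack(42, true, true) // requeue every unacked delivery up to tag 42
+	ch.Reject(43, true)     // requeue only the delivery with tag 43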
+
+See also Delivery.Reject
+*/
+func (ch *Channel) Reject(tag uint64, requeue bool) error {
+	ch.m.Lock()
+	defer ch.m.Unlock()
+
+	return ch.send(&basicReject{
+		DeliveryTag: tag,
+		Requeue:     requeue,
+	})
+}
+
+// GetNextPublishSeqNo returns the sequence number of the next message to be
+// published, when in confirm mode.
+func (ch *Channel) GetNextPublishSeqNo() uint64 {
+	ch.confirms.m.Lock()
+	defer ch.confirms.m.Unlock()
+
+	return ch.confirms.published + 1
+}
diff --git a/vendor/github.com/rabbitmq/amqp091-go/connection.go b/vendor/github.com/rabbitmq/amqp091-go/connection.go
new file mode 100644
index 0000000000..3d50d95580
--- /dev/null
+++ b/vendor/github.com/rabbitmq/amqp091-go/connection.go
@@ -0,0 +1,1099 @@
+// Copyright (c) 2021 VMware, Inc. or its affiliates. All Rights Reserved.
+// Copyright (c) 2012-2021, Sean Treadway, SoundCloud Ltd.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package amqp091
+
+import (
+	"bufio"
+	"crypto/tls"
+	"crypto/x509"
+	"errors"
+	"fmt"
+	"io"
+	"net"
+	"os"
+	"reflect"
+	"strconv"
+	"strings"
+	"sync"
+	"sync/atomic"
+	"time"
+)
+
+const (
+	maxChannelMax = (2 << 15) - 1
+
+	defaultHeartbeat         = 10 * time.Second
+	defaultConnectionTimeout = 30 * time.Second
+	defaultProduct           = "AMQP 0.9.1 Client"
+	buildVersion             = "1.8.1"
+	platform                 = "golang"
+	// Safer default that makes channel leaks a lot easier to spot
+	// before they create operational headaches. See https://github.com/rabbitmq/rabbitmq-server/issues/1593.
+	defaultChannelMax = (2 << 10) - 1
+	defaultLocale     = "en_US"
+)
+
+// Config is used in DialConfig and Open to specify the desired tuning
+// parameters used during a connection open handshake. The negotiated tuning
+// will be stored in the returned connection's Config field.
+type Config struct {
+	// The SASL mechanisms to try in the client request, and the successful
+	// mechanism used on the Connection object.
+	// If SASL is nil, PlainAuth from the URL is used.
+	SASL []Authentication
+
+	// Vhost specifies the namespace of permissions, exchanges, queues and
+	// bindings on the server. Dial sets this to the path parsed from the URL.
+	Vhost string
+
+	ChannelMax int           // 0 max channels means 2^16 - 1
+	FrameSize  int           // 0 max bytes means unlimited
+	Heartbeat  time.Duration // less than 1s uses the server's interval
+
+	// TLSClientConfig specifies the client configuration of the TLS connection
+	// when establishing a tls transport.
+	// If the URL uses an amqps scheme, then an empty tls.Config with the
+	// ServerName from the URL is used.
+	TLSClientConfig *tls.Config
+
+	// Properties is a table of properties that the client advertises to the server.
+	// This is an optional setting - if the application does not set this,
+	// the underlying library will use a generic set of client properties.
+	Properties Table
+
+	// Connection locale that we expect to always be en_US
+	// Even though servers must return it as per the AMQP 0-9-1 spec,
+	// we are not aware of it being used other than to satisfy the spec requirements
+	Locale string
+
+	// Dial returns a net.Conn prepared for a TLS handshake with TLSClientConfig,
+	// then an AMQP connection handshake.
+	// If Dial is nil, net.DialTimeout with a 30s connection and 30s deadline is
+	// used during TLS and AMQP handshaking.
+	Dial func(network, addr string) (net.Conn, error)
+}
+
+// NewConnectionProperties creates an amqp.Table to be used as amqp.Config.Properties.
+//
+// Defaults to library-defined values.
+// For empty properties, use make(amqp.Table) instead.
+func NewConnectionProperties() Table {
+	return Table{
+		"product":  defaultProduct,
+		"version":  buildVersion,
+		"platform": platform,
+	}
+}
+
+// Connection manages the serialization and deserialization of frames from IO
+// and dispatches the frames to the appropriate channel. All RPC methods and
+// asynchronous Publishing, Delivery, Ack, Nack and Return messages are
+// multiplexed on this channel. There must always be active receivers for
+// every asynchronous message on this connection.
+type Connection struct {
+	destructor sync.Once  // shutdown once
+	sendM      sync.Mutex // conn writer mutex
+	m          sync.Mutex // struct field mutex
+
+	conn io.ReadWriteCloser
+
+	rpc       chan message
+	writer    *writer
+	sends     chan time.Time     // timestamps of each frame sent
+	deadlines chan readDeadliner // heartbeater updates read deadlines
+
+	allocator *allocator // id generator valid after openTune
+	channels  map[uint16]*Channel
+
+	noNotify bool // true when we will never notify again
+	closes   []chan *Error
+	blocks   []chan Blocking
+
+	errors chan *Error
+
+	Config Config // The negotiated Config after connection.open
+
+	Major      int      // Server's major version
+	Minor      int      // Server's minor version
+	Properties Table    // Server properties
+	Locales    []string // Server locales
+
+	closed int32 // Will be 1 if the connection is closed, 0 otherwise. Should only be accessed as atomic
+}
+
+type readDeadliner interface {
+	SetReadDeadline(time.Time) error
+}
+
+// DefaultDial establishes a connection when config.Dial is not provided
+func DefaultDial(connectionTimeout time.Duration) func(network, addr string) (net.Conn, error) {
+	return func(network, addr string) (net.Conn, error) {
+		conn, err := net.DialTimeout(network, addr, connectionTimeout)
+		if err != nil {
+			return nil, err
+		}
+
+		// Heartbeating hasn't started yet, don't stall forever on a dead server.
+		// A deadline is set for TLS and AMQP handshaking. After AMQP is established,
+		// the deadline is cleared in openComplete.
+		if err := conn.SetDeadline(time.Now().Add(connectionTimeout)); err != nil {
+			return nil, err
+		}
+
+		return conn, nil
+	}
+}
+
+// Dial accepts a string in the AMQP URI format and returns a new Connection
+// over TCP using PlainAuth. Defaults to a server heartbeat interval of 10
+// seconds and sets the handshake deadline to 30 seconds. After handshake,
+// deadlines are cleared.
+//
+// Dial uses the zero value of tls.Config when it encounters an amqps://
+// scheme. It is equivalent to calling DialTLS(url, nil).
+func Dial(url string) (*Connection, error) {
+	return DialConfig(url, Config{
+		Heartbeat: defaultHeartbeat,
+		Locale:    defaultLocale,
+	})
+}
+
+// DialTLS accepts a string in the AMQP URI format and returns a new Connection
+// over TCP using PlainAuth. Defaults to a server heartbeat interval of 10
+// seconds and sets the initial read deadline to 30 seconds.
+//
+// DialTLS uses the provided tls.Config when encountering an amqps:// scheme.
+func DialTLS(url string, amqps *tls.Config) (*Connection, error) {
+	return DialConfig(url, Config{
+		Heartbeat:       defaultHeartbeat,
+		TLSClientConfig: amqps,
+		Locale:          defaultLocale,
+	})
+}
+
+// DialTLS_ExternalAuth accepts a string in the AMQP URI format and returns a
+// new Connection over TCP using EXTERNAL auth. Defaults to a server heartbeat
+// interval of 10 seconds and sets the initial read deadline to 30 seconds.
+//
+// This mechanism is used when RabbitMQ is configured for EXTERNAL auth with
+// the ssl_cert_login plugin for userless/passwordless logons.
+//
+// DialTLS_ExternalAuth uses the provided tls.Config when encountering an
+// amqps:// scheme.
+func DialTLS_ExternalAuth(url string, amqps *tls.Config) (*Connection, error) {
+	return DialConfig(url, Config{
+		Heartbeat:       defaultHeartbeat,
+		TLSClientConfig: amqps,
+		SASL:            []Authentication{&ExternalAuth{}},
+	})
+}
+
+// DialConfig accepts a string in the AMQP URI format and a configuration for
+// the transport and connection setup, returning a new Connection. Defaults to
+// a server heartbeat interval of 10 seconds and sets the initial read deadline
+// to 30 seconds.
+func DialConfig(url string, config Config) (*Connection, error) {
+	var err error
+	var conn net.Conn
+
+	uri, err := ParseURI(url)
+	if err != nil {
+		return nil, err
+	}
+
+	if config.SASL == nil {
+		config.SASL = []Authentication{uri.PlainAuth()}
+	}
+
+	if config.Vhost == "" {
+		config.Vhost = uri.Vhost
+	}
+
+	addr := net.JoinHostPort(uri.Host, strconv.FormatInt(int64(uri.Port), 10))
+
+	dialer := config.Dial
+	if dialer == nil {
+		dialer = DefaultDial(defaultConnectionTimeout)
+	}
+
+	conn, err = dialer("tcp", addr)
+	if err != nil {
+		return nil, err
+	}
+
+	if uri.Scheme == "amqps" {
+		if config.TLSClientConfig == nil {
+			tlsConfig, err := tlsConfigFromURI(uri)
+			if err != nil {
+				return nil, fmt.Errorf("create TLS config from URI: %w", err)
+			}
+			config.TLSClientConfig = tlsConfig
+		}
+
+		// If ServerName has not been specified in TLSClientConfig,
+		// set it to the URI host used for this connection.
+		if config.TLSClientConfig.ServerName == "" {
+			config.TLSClientConfig.ServerName = uri.Host
+		}
+
+		client := tls.Client(conn, config.TLSClientConfig)
+		if err := client.Handshake(); err != nil {
+			conn.Close()
+			return nil, err
+		}
+
+		conn = client
+	}
+
+	return Open(conn, config)
+}
+
+/*
+Open accepts an already established connection, or other io.ReadWriteCloser as
+a transport. Use this method if you have established a TLS connection or wish
+to use your own custom transport.
+*/
+func Open(conn io.ReadWriteCloser, config Config) (*Connection, error) {
+	c := &Connection{
+		conn:      conn,
+		writer:    &writer{bufio.NewWriter(conn)},
+		channels:  make(map[uint16]*Channel),
+		rpc:       make(chan message),
+		sends:     make(chan time.Time),
+		errors:    make(chan *Error, 1),
+		deadlines: make(chan readDeadliner, 1),
+	}
+	go c.reader(conn)
+	return c, c.open(config)
+}
+
+/*
+UpdateSecret updates the secret used to authenticate this connection. It is used when
+secrets have an expiration date and need to be renewed, like OAuth 2 tokens.
+
+It returns an error if the operation is not successful, or if the connection is closed.
+*/
+func (c *Connection) UpdateSecret(newSecret, reason string) error {
+	if c.IsClosed() {
+		return ErrClosed
+	}
+	return c.call(&connectionUpdateSecret{
+		NewSecret: newSecret,
+		Reason:    reason,
+	}, &connectionUpdateSecretOk{})
+}
+
+/*
+LocalAddr returns the local TCP peer address, or ":0" (the zero value of net.TCPAddr)
+as a fallback default value if the underlying transport does not support LocalAddr().
+*/
+func (c *Connection) LocalAddr() net.Addr {
+	if conn, ok := c.conn.(interface {
+		LocalAddr() net.Addr
+	}); ok {
+		return conn.LocalAddr()
+	}
+	return &net.TCPAddr{}
+}
+
+/*
+RemoteAddr returns the remote TCP peer address, if known.
+*/
+func (c *Connection) RemoteAddr() net.Addr {
+	if conn, ok := c.conn.(interface {
+		RemoteAddr() net.Addr
+	}); ok {
+		return conn.RemoteAddr()
+	}
+	return &net.TCPAddr{}
+}
+
+// ConnectionState returns basic TLS details of the underlying transport.
+// Returns a zero value when the underlying connection does not implement
+// ConnectionState() tls.ConnectionState.
+func (c *Connection) ConnectionState() tls.ConnectionState {
+	if conn, ok := c.conn.(interface {
+		ConnectionState() tls.ConnectionState
+	}); ok {
+		return conn.ConnectionState()
+	}
+	return tls.ConnectionState{}
+}
+
+/*
+NotifyClose registers a listener for close events either initiated by an error
+accompanying a connection.close method or by a normal shutdown.
+
+The chan provided will be closed when the Connection is closed and on a
+graceful close, no error will be sent.
+
+In case of a non-graceful close the error will be sent synchronously by the
+library, so the caller must consume from the channel in order to avoid deadlocks.
+
+To reconnect after a transport or protocol error, register a listener here and
+re-run your setup process.
+*/
+func (c *Connection) NotifyClose(receiver chan *Error) chan *Error {
+	c.m.Lock()
+	defer c.m.Unlock()
+
+	if c.noNotify {
+		close(receiver)
+	} else {
+		c.closes = append(c.closes, receiver)
+	}
+
+	return receiver
+}
+
+/*
+NotifyBlocked registers a listener for RabbitMQ specific TCP flow control
+method extensions connection.blocked and connection.unblocked. Flow control is
+active with a reason when Blocking.Blocked is true. When a Connection is
+blocked, all methods will block across all connections until server resources
+become free again.
+
+This optional extension is supported by the server when the
+"connection.blocked" server capability key is true.
+*/
+func (c *Connection) NotifyBlocked(receiver chan Blocking) chan Blocking {
+	c.m.Lock()
+	defer c.m.Unlock()
+
+	if c.noNotify {
+		close(receiver)
+	} else {
+		c.blocks = append(c.blocks, receiver)
+	}
+
+	return receiver
+}
+
+/*
+Close requests and waits for the response to close the AMQP connection.
+
+It's advisable to use this method when publishing to ensure all kernel buffers
+have been flushed on the server and client before exiting.
+
+An error indicates that the server may not have received this request to close
+but the connection should be treated as closed regardless.
+
+After returning from this call, all resources associated with this connection,
+including the underlying io, Channels, Notify listeners and Channel consumers
+will also be closed.
+*/
+func (c *Connection) Close() error {
+	if c.IsClosed() {
+		return ErrClosed
+	}
+
+	defer c.shutdown(nil)
+	return c.call(
+		&connectionClose{
+			ReplyCode: replySuccess,
+			ReplyText: "kthxbai",
+		},
+		&connectionCloseOk{},
+	)
+}
+
+// CloseDeadline requests and waits for the response to close this AMQP connection.
+//
+// Accepts a deadline for waiting for the server response. The deadline is passed
+// to the low-level connection i.e. network socket.
+//
+// Regardless of the error returned, the connection is considered closed, and it
+// should not be used after calling this function.
+//
+// In the event of an I/O timeout, connection-closed listeners are NOT informed.
+//
+// After returning from this call, all resources associated with this connection,
+// including the underlying io, Channels, Notify listeners and Channel consumers
+// will also be closed.
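+//
+// A usage sketch (the 5-second deadline is illustrative):
+//
+//	if err := conn.CloseDeadline(time.Now().Add(5 * time.Second)); err != nil {
+//		// the connection is closed regardless; err may be an I/O timeout
+//	}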
+
+// CloseDeadline requests and waits for the response to close this AMQP connection.
+//
+// Accepts a deadline for waiting for the server response. The deadline is
+// passed to the low-level connection, i.e. the network socket.
+//
+// Regardless of the error returned, the connection is considered closed, and it
+// should not be used after calling this function.
+//
+// In the event of an I/O timeout, connection-closed listeners are NOT informed.
+//
+// After returning from this call, all resources associated with this connection,
+// including the underlying io, Channels, Notify listeners and Channel consumers
+// will also be closed.
+func (c *Connection) CloseDeadline(deadline time.Time) error {
+	if c.IsClosed() {
+		return ErrClosed
+	}
+
+	defer c.shutdown(nil)
+
+	err := c.setDeadline(deadline)
+	if err != nil {
+		return err
+	}
+
+	return c.call(
+		&connectionClose{
+			ReplyCode: replySuccess,
+			ReplyText: "kthxbai",
+		},
+		&connectionCloseOk{},
+	)
+}
+
+func (c *Connection) closeWith(err *Error) error {
+	if c.IsClosed() {
+		return ErrClosed
+	}
+
+	defer c.shutdown(err)
+
+	return c.call(
+		&connectionClose{
+			ReplyCode: uint16(err.Code),
+			ReplyText: err.Reason,
+		},
+		&connectionCloseOk{},
+	)
+}
+
+// IsClosed returns true if the connection is marked as closed, otherwise false
+// is returned.
+func (c *Connection) IsClosed() bool {
+	return atomic.LoadInt32(&c.closed) == 1
+}
+
+// setDeadline is a wrapper to type assert Connection.conn and set an I/O
+// deadline on the underlying TCP connection socket, by calling
+// net.Conn.SetDeadline(). It returns an error if the type assertion fails,
+// although this should never happen.
+func (c *Connection) setDeadline(t time.Time) error {
+	con, ok := c.conn.(net.Conn)
+	if !ok {
+		return errInvalidTypeAssertion
+	}
+	return con.SetDeadline(t)
+}
+
+func (c *Connection) send(f frame) error {
+	if c.IsClosed() {
+		return ErrClosed
+	}
+
+	c.sendM.Lock()
+	err := c.writer.WriteFrame(f)
+	c.sendM.Unlock()
+
+	if err != nil {
+		// shutdown could be re-entrant from signaling notify chans
+		go c.shutdown(&Error{
+			Code:   FrameError,
+			Reason: err.Error(),
+		})
+	} else {
+		// Broadcast we sent a frame, reducing heartbeats, only
+		// if there is something that can receive - like a non-reentrant
+		// call or if the heartbeater isn't running
+		select {
+		case c.sends <- time.Now():
+		default:
+		}
+	}
+
+	return err
+}
+
+// endSendUnflushed is intended to be used with sendUnflushed() to end a
+// sequence of sendUnflushed() calls and flush the connection.
+func (c *Connection) endSendUnflushed() error {
+	c.sendM.Lock()
+	defer c.sendM.Unlock()
+	return c.flush()
+}
+
+// sendUnflushed performs an *unflushed* write. It is otherwise equivalent to
+// send(); a separate flush() call explicitly flushes the buffer after all
+// Frames are written.
+//
+// Why is this a thing?
+//
+// send() uses writer.WriteFrame(), which writes the Frame and then flushes the
+// buffer. For cases like the sendOpen() method on Channel, which sends multiple
+// Frames (methodFrame, headerFrame, N x bodyFrame), flushing after each Frame
+// is inefficient: it negates much of the benefit of a buffered writer and
+// results in more syscalls than necessary. That per-frame flushing has a
+// significant performance impact when sending small (basicPublish) messages,
+// so this method defers the flush to an explicit call once all related Frames
+// have been written.
+func (c *Connection) sendUnflushed(f frame) error {
+	if c.IsClosed() {
+		return ErrClosed
+	}
+
+	c.sendM.Lock()
+	err := c.writer.WriteFrameNoFlush(f)
+	c.sendM.Unlock()
+
+	if err != nil {
+		// shutdown could be re-entrant from signaling notify chans
+		go c.shutdown(&Error{
+			Code:   FrameError,
+			Reason: err.Error(),
+		})
+	}
+
+	return err
+}
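The saving that sendUnflushed()/flush() target is the standard buffered-writer batching pattern from the standard library; an illustrative sketch, independent of this library:

package main

import (
    "bufio"
    "fmt"
    "os"
)

func main() {
    // One flush for a batch of writes: the same idea sendUnflushed() +
    // endSendUnflushed() apply to the method/header/body frames of a publish.
    w := bufio.NewWriter(os.Stdout)
    for i := 0; i < 3; i++ {
        fmt.Fprintf(w, "frame %d\n", i) // buffered; no syscall yet
    }
    if err := w.Flush(); err != nil { // a single write syscall for the batch
        panic(err)
    }
}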
+
+// flush is intended to be used with sendUnflushed() to explicitly flush the
+// buffer after all required Frames have been written to it.
+func (c *Connection) flush() (err error) {
+	if buf, ok := c.writer.w.(*bufio.Writer); ok {
+		err = buf.Flush()
+
+		// Moving the send notifier to flush improves basicPublish performance
+		// for the small-message case. As sendUnflushed + flush is used when
+		// sending semantically related Frames (e.g. a Message like
+		// basicPublish), there is no real advantage to notifying per Frame
+		// rather than per group of related Frames, and for small messages
+		// time.Now() is (relatively) expensive.
+		if err == nil {
+			// Broadcast we sent a frame, reducing heartbeats, only
+			// if there is something that can receive - like a non-reentrant
+			// call or if the heartbeater isn't running
+			select {
+			case c.sends <- time.Now():
+			default:
+			}
+		}
+	}
+
+	return
+}
+
+func (c *Connection) shutdown(err *Error) {
+	atomic.StoreInt32(&c.closed, 1)
+
+	c.destructor.Do(func() {
+		c.m.Lock()
+		defer c.m.Unlock()
+
+		if err != nil {
+			for _, c := range c.closes {
+				c <- err
+			}
+			c.errors <- err
+		}
+		// Shutdown handler goroutine can still receive the result.
+		close(c.errors)
+
+		for _, c := range c.closes {
+			close(c)
+		}
+
+		for _, c := range c.blocks {
+			close(c)
+		}
+
+		// Shutdown the channel, but do not use closeChannel() as it calls
+		// releaseChannel() which requires the connection lock.
+		//
+		// Ranging over c.channels and calling releaseChannel() that mutates
+		// c.channels is racy - see commit 6063341 for an example.
+		for _, ch := range c.channels {
+			ch.shutdown(err)
+		}
+
+		c.conn.Close()
+
+		c.channels = nil
+		c.allocator = nil
+		c.noNotify = true
+	})
+}
+
+// All methods sent to the connection channel should be synchronous so we
+// can handle them directly without a framing component.
+func (c *Connection) demux(f frame) {
+	if f.channel() == 0 {
+		c.dispatch0(f)
+	} else {
+		c.dispatchN(f)
+	}
+}
+
+func (c *Connection) dispatch0(f frame) {
+	switch mf := f.(type) {
+	case *methodFrame:
+		switch m := mf.Method.(type) {
+		case *connectionClose:
+			// Send immediately as shutdown will close our side of the writer.
+			f := &methodFrame{ChannelId: 0, Method: &connectionCloseOk{}}
+			if err := c.send(f); err != nil {
+				Logger.Printf("error sending connectionCloseOk, error: %+v", err)
+			}
+			c.shutdown(newError(m.ReplyCode, m.ReplyText))
+		case *connectionBlocked:
+			for _, c := range c.blocks {
+				c <- Blocking{Active: true, Reason: m.Reason}
+			}
+		case *connectionUnblocked:
+			for _, c := range c.blocks {
+				c <- Blocking{Active: false}
+			}
+		default:
+			c.rpc <- m
+		}
+	case *heartbeatFrame:
+		// All reads reset our deadline, so we can drop this frame.
+	default:
+		// Channel 0 only responds to methods and heartbeats; anything else
+		// is a protocol violation.
+		if err := c.closeWith(ErrUnexpectedFrame); err != nil {
+			Logger.Printf("error sending connectionCloseOk with ErrUnexpectedFrame, error: %+v", err)
+		}
+	}
+}
+
+func (c *Connection) dispatchN(f frame) {
+	c.m.Lock()
+	channel, ok := c.channels[f.channel()]
+	if ok {
+		updateChannel(f, channel)
+	} else {
+		Logger.Printf("[debug] dropping frame, channel %d does not exist", f.channel())
+	}
+	c.m.Unlock()
+
+	// Note: this could result in concurrent dispatch depending on
+	// how channels are managed in an application
+	if ok {
+		channel.recv(channel, f)
+	} else {
+		c.dispatchClosed(f)
+	}
+}
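The connection.blocked/connection.unblocked methods handled by dispatch0 above surface to applications through NotifyBlocked. A minimal sketch, assuming a reachable local broker:

package main

import (
    "log"

    amqp "github.com/rabbitmq/amqp091-go"
)

func main() {
    conn, err := amqp.Dial("amqp://guest:guest@localhost:5672/") // assumed broker
    if err != nil {
        log.Fatal(err)
    }
    defer conn.Close()

    // Blocking notifications delivered by dispatch0 above arrive here.
    for b := range conn.NotifyBlocked(make(chan amqp.Blocking, 1)) {
        if b.Active {
            log.Printf("connection blocked: %s", b.Reason)
        } else {
            log.Printf("connection unblocked")
        }
    }
}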
+
+// section 2.3.7: "When a peer decides to close a channel or connection, it
+// sends a Close method. The receiving peer MUST respond to a Close with a
+// Close-Ok, and then both parties can close their channel or connection. Note
+// that if peers ignore Close, deadlock can happen when both peers send Close
+// at the same time."
+//
+// When we don't have a channel, we must respond with close-ok on a close
+// method. This can happen between a channel exception on an asynchronous
+// method like basic.publish and a synchronous close with channel.close.
+// In that case, we'll get both a channel.close and channel.close-ok in any
+// order.
+func (c *Connection) dispatchClosed(f frame) {
+	// Only consider method frames, drop content/header frames
+	if mf, ok := f.(*methodFrame); ok {
+		switch mf.Method.(type) {
+		case *channelClose:
+			f := &methodFrame{ChannelId: f.channel(), Method: &channelCloseOk{}}
+			if err := c.send(f); err != nil {
+				Logger.Printf("error sending channelCloseOk, channel id: %d error: %+v", f.channel(), err)
+			}
+		case *channelCloseOk:
+			// we are already closed, so do nothing
+		default:
+			// unexpected method on closed channel
+			if err := c.closeWith(ErrClosed); err != nil {
+				Logger.Printf("error sending connectionCloseOk with ErrClosed, error: %+v", err)
+			}
+		}
+	}
+}
+
+// reader reads each frame off the IO and hands it off to the connection object
+// that will demux the streams and dispatch to one of the opened channels or
+// handle it on channel 0 (the connection channel).
+func (c *Connection) reader(r io.Reader) {
+	buf := bufio.NewReader(r)
+	frames := &reader{buf}
+	conn, haveDeadliner := r.(readDeadliner)
+
+	defer close(c.rpc)
+
+	for {
+		frame, err := frames.ReadFrame()
+
+		if err != nil {
+			c.shutdown(&Error{Code: FrameError, Reason: err.Error()})
+			return
+		}
+
+		c.demux(frame)
+
+		if haveDeadliner {
+			select {
+			case c.deadlines <- conn:
+			default:
+				// On c.Close() c.heartbeater() might exit just before
+				// c.deadlines <- conn is called, which would leave this
+				// goroutine stuck forever.
+			}
+		}
+	}
+}
+
+// heartbeater ensures that at least one frame is being sent at the tuned
+// interval with a jitter tolerance of 1s.
+func (c *Connection) heartbeater(interval time.Duration, done chan *Error) {
+	const maxServerHeartbeatsInFlight = 3
+
+	var sendTicks <-chan time.Time
+	if interval > 0 {
+		ticker := time.NewTicker(interval)
+		defer ticker.Stop()
+		sendTicks = ticker.C
+	}
+
+	lastSent := time.Now()
+
+	for {
+		select {
+		case at, stillSending := <-c.sends:
+			// When actively sending, depend on sent frames to reset server timer
+			if stillSending {
+				lastSent = at
+			} else {
+				return
+			}
+
+		case at := <-sendTicks:
+			// When idle, fill the space with a heartbeat frame
+			if at.Sub(lastSent) > interval-time.Second {
+				if err := c.send(&heartbeatFrame{}); err != nil {
+					// send heartbeats even after close/closeOk so we
+					// tick until the connection starts erroring
+					return
+				}
+			}
+
+		case conn := <-c.deadlines:
+			// When reading, reset our side of the deadline, if we've negotiated one with
+			// a deadline that covers at least 2 server heartbeats
+			if interval > 0 {
+				if err := conn.SetReadDeadline(time.Now().Add(maxServerHeartbeatsInFlight * interval)); err != nil {
+					var opErr *net.OpError
+					if !errors.As(err, &opErr) {
+						Logger.Printf("error setting read deadline in heartbeater: %+v", err)
+						return
+					}
+				}
+			}
+
+		case <-done:
+			return
+		}
+	}
+}
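The interval driving heartbeater above comes from the negotiated Config.Heartbeat (see openTune further down). A hedged sketch of requesting a specific interval, assuming a reachable local broker:

package main

import (
    "log"
    "time"

    amqp "github.com/rabbitmq/amqp091-go"
)

func main() {
    // Request a 30s heartbeat; the effective value is negotiated with the
    // server during connection.tune (smaller non-zero value wins via pick()).
    conn, err := amqp.DialConfig("amqp://guest:guest@localhost:5672/", amqp.Config{
        Heartbeat: 30 * time.Second,
    })
    if err != nil {
        log.Fatal(err)
    }
    defer conn.Close()
}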
+
+// isCapable is a convenience method to inspect the
+// Connection.Properties["capabilities"] Table for server-identified
+// capabilities like "basic.ack" or "confirm.select".
+func (c *Connection) isCapable(featureName string) bool {
+	capabilities, _ := c.Properties["capabilities"].(Table)
+	hasFeature, _ := capabilities[featureName].(bool)
+	return hasFeature
+}
+
+// allocateChannel records but does not open a new channel with a unique id.
+// This method is the initial part of the channel lifecycle and is paired with
+// releaseChannel.
+func (c *Connection) allocateChannel() (*Channel, error) {
+	c.m.Lock()
+	defer c.m.Unlock()
+
+	if c.IsClosed() {
+		return nil, ErrClosed
+	}
+
+	id, ok := c.allocator.next()
+	if !ok {
+		return nil, ErrChannelMax
+	}
+
+	ch := newChannel(c, uint16(id))
+	c.channels[uint16(id)] = ch
+
+	return ch, nil
+}
+
+// releaseChannel removes a channel from the registry as the final part of the
+// channel lifecycle.
+func (c *Connection) releaseChannel(id uint16) {
+	c.m.Lock()
+	defer c.m.Unlock()
+
+	if !c.IsClosed() {
+		delete(c.channels, id)
+		c.allocator.release(int(id))
+	}
+}
+
+// openChannel allocates and opens a channel; it must be paired with
+// closeChannel.
+func (c *Connection) openChannel() (*Channel, error) {
+	ch, err := c.allocateChannel()
+	if err != nil {
+		return nil, err
+	}
+
+	if err := ch.open(); err != nil {
+		c.releaseChannel(ch.id)
+		return nil, err
+	}
+	return ch, nil
+}
+
+// closeChannel releases and initiates a shutdown of the channel. All channel
+// closures should be initiated here for proper channel lifecycle management on
+// this connection.
+func (c *Connection) closeChannel(ch *Channel, e *Error) {
+	ch.shutdown(e)
+	c.releaseChannel(ch.id)
+}
+
+/*
+Channel opens a unique, concurrent server channel to process the bulk of AMQP
+messages. Any error from methods on this receiver will render the receiver
+invalid and a new Channel should be opened.
+*/
+func (c *Connection) Channel() (*Channel, error) {
+	return c.openChannel()
+}
+
+func (c *Connection) call(req message, res ...message) error {
+	// Special case for when the protocol header frame is sent instead of a
+	// request method
+	if req != nil {
+		if err := c.send(&methodFrame{ChannelId: 0, Method: req}); err != nil {
+			return err
+		}
+	}
+
+	msg, ok := <-c.rpc
+	if !ok {
+		err, errorsChanIsOpen := <-c.errors
+		if !errorsChanIsOpen {
+			return ErrClosed
+		}
+		return err
+	}
+
+	// Try to match one of the result types
+	for _, try := range res {
+		if reflect.TypeOf(msg) == reflect.TypeOf(try) {
+			// *res = *msg
+			vres := reflect.ValueOf(try).Elem()
+			vmsg := reflect.ValueOf(msg).Elem()
+			vres.Set(vmsg)
+			return nil
+		}
+	}
+	return ErrCommandInvalid
+}
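Channel above is the entry point applications use after dialing. A minimal end-to-end sketch, assuming a reachable local broker; the queue name is illustrative:

package main

import (
    "context"
    "log"

    amqp "github.com/rabbitmq/amqp091-go"
)

func main() {
    conn, err := amqp.Dial("amqp://guest:guest@localhost:5672/") // assumed broker
    if err != nil {
        log.Fatal(err)
    }
    defer conn.Close()

    ch, err := conn.Channel() // runs the allocate/open lifecycle above
    if err != nil {
        log.Fatal(err)
    }
    defer ch.Close()

    q, err := ch.QueueDeclare("demo", false, true, false, false, nil)
    if err != nil {
        log.Fatal(err)
    }
    err = ch.PublishWithContext(context.Background(), "", q.Name, false, false,
        amqp.Publishing{ContentType: "text/plain", Body: []byte("hello")})
    if err != nil {
        log.Fatal(err)
    }
}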
+
+// Communication flow to open, use and close a connection. 'C:' are
+// frames sent by the Client. 'S:' are frames sent by the Server.
+//
+// Connection       = open-Connection *use-Connection close-Connection
+//
+// open-Connection  = C:protocol-header
+//                    S:START C:START-OK
+//                    *challenge
+//                    S:TUNE C:TUNE-OK
+//                    C:OPEN S:OPEN-OK
+//
+// challenge        = S:SECURE C:SECURE-OK
+//
+// use-Connection   = *channel
+//
+// close-Connection = C:CLOSE S:CLOSE-OK
+//                    S:CLOSE C:CLOSE-OK
+func (c *Connection) open(config Config) error {
+	if err := c.send(&protocolHeader{}); err != nil {
+		return err
+	}
+
+	return c.openStart(config)
+}
+
+func (c *Connection) openStart(config Config) error {
+	start := &connectionStart{}
+
+	if err := c.call(nil, start); err != nil {
+		return err
+	}
+
+	c.Major = int(start.VersionMajor)
+	c.Minor = int(start.VersionMinor)
+	c.Properties = start.ServerProperties
+	c.Locales = strings.Split(start.Locales, " ")
+
+	// eventually support challenge/response here by also responding to
+	// connectionSecure.
+	auth, ok := pickSASLMechanism(config.SASL, strings.Split(start.Mechanisms, " "))
+	if !ok {
+		return ErrSASL
+	}
+
+	// Save this mechanism off as the one we chose
+	c.Config.SASL = []Authentication{auth}
+
+	// Set the connection locale to client locale
+	c.Config.Locale = config.Locale
+
+	return c.openTune(config, auth)
+}
+
+func (c *Connection) openTune(config Config, auth Authentication) error {
+	if len(config.Properties) == 0 {
+		config.Properties = NewConnectionProperties()
+	}
+
+	config.Properties["capabilities"] = Table{
+		"connection.blocked":     true,
+		"consumer_cancel_notify": true,
+		"basic.nack":             true,
+		"publisher_confirms":     true,
+	}
+
+	ok := &connectionStartOk{
+		ClientProperties: config.Properties,
+		Mechanism:        auth.Mechanism(),
+		Response:         auth.Response(),
+		Locale:           config.Locale,
+	}
+	tune := &connectionTune{}
+
+	if err := c.call(ok, tune); err != nil {
+		// per spec, a connection can only be closed when it has been opened
+		// so at this point, we know it's an auth error, but the socket
+		// was closed instead. Return a meaningful error.
+		return ErrCredentials
+	}
+
+	// Edge case that may race with c.shutdown()
+	// https://github.com/rabbitmq/amqp091-go/issues/170
+	c.m.Lock()
+
+	// When the server and client both use default 0, then the max channel is
+	// only limited by uint16.
+	c.Config.ChannelMax = pick(config.ChannelMax, int(tune.ChannelMax))
+	if c.Config.ChannelMax == 0 {
+		c.Config.ChannelMax = defaultChannelMax
+	}
+	c.Config.ChannelMax = min(c.Config.ChannelMax, maxChannelMax)
+
+	c.allocator = newAllocator(1, c.Config.ChannelMax)
+
+	c.m.Unlock()
+
+	// Frame size includes headers and end byte (len(payload)+8). Even if
+	// this is less than FrameMinSize, use what the server sends, because the
+	// alternative is to stop the handshake here.
+	c.Config.FrameSize = pick(config.FrameSize, int(tune.FrameMax))
+
+	// Save this off for resetDeadline()
+	c.Config.Heartbeat = time.Second * time.Duration(pick(
+		int(config.Heartbeat/time.Second),
+		int(tune.Heartbeat)))
+
+	// "The client should start sending heartbeats after receiving a
+	// Connection.Tune method"
+	go c.heartbeater(c.Config.Heartbeat/2, c.NotifyClose(make(chan *Error, 1)))
+
+	if err := c.send(&methodFrame{
+		ChannelId: 0,
+		Method: &connectionTuneOk{
+			ChannelMax: uint16(c.Config.ChannelMax),
+			FrameMax:   uint32(c.Config.FrameSize),
+			Heartbeat:  uint16(c.Config.Heartbeat / time.Second),
+		},
+	}); err != nil {
+		return err
+	}
+
+	return c.openVhost(config)
+}
+
+func (c *Connection) openVhost(config Config) error {
+	req := &connectionOpen{VirtualHost: config.Vhost}
+	res := &connectionOpenOk{}
+
+	if err := c.call(req, res); err != nil {
+		// Cannot be closed yet, but we know it's a vhost problem
+		return ErrVhost
+	}
+
+	c.Config.Vhost = config.Vhost
+
+	return c.openComplete()
+}
+
+// openComplete performs any final Connection initialization dependent on the
+// connection handshake and clears any state needed for TLS and AMQP handshaking.
+func (c *Connection) openComplete() error {
+	// We clear the deadlines and let the heartbeater reset the read deadline if requested.
+	// RabbitMQ uses TCP flow control at this point for pushback so Writes can
+	// intentionally block.
+	if deadliner, ok := c.conn.(interface {
+		SetDeadline(time.Time) error
+	}); ok {
+		_ = deadliner.SetDeadline(time.Time{})
+	}
+
+	return nil
+}
+
+// tlsConfigFromURI tries to create a TLS configuration based on query
+// parameters. Returns a default (empty) config when no suitable client cert
+// and/or client key is provided. Returns an error when certificates cannot be
+// parsed.
+func tlsConfigFromURI(uri URI) (*tls.Config, error) {
+	var certPool *x509.CertPool
+	if uri.CACertFile != "" {
+		data, err := os.ReadFile(uri.CACertFile)
+		if err != nil {
+			return nil, fmt.Errorf("read CA certificate: %w", err)
+		}
+
+		certPool = x509.NewCertPool()
+		certPool.AppendCertsFromPEM(data)
+	} else if sysPool, err := x509.SystemCertPool(); err != nil {
+		return nil, fmt.Errorf("load system certificates: %w", err)
+	} else {
+		certPool = sysPool
+	}
+
+	if uri.CertFile == "" || uri.KeyFile == "" {
+		// no client auth (mTLS), just server auth
+		return &tls.Config{
+			RootCAs:    certPool,
+			ServerName: uri.ServerName,
+		}, nil
+	}
+
+	certificate, err := tls.LoadX509KeyPair(uri.CertFile, uri.KeyFile)
+	if err != nil {
+		return nil, fmt.Errorf("load client certificate: %w", err)
+	}
+
+	return &tls.Config{
+		Certificates: []tls.Certificate{certificate},
+		RootCAs:      certPool,
+		ServerName:   uri.ServerName,
+	}, nil
+}
+
+func max(a, b int) int {
+	if a > b {
+		return a
+	}
+	return b
+}
+
+func min(a, b int) int {
+	if a < b {
+		return a
+	}
+	return b
+}
+
+func pick(client, server int) int {
+	if client == 0 || server == 0 {
+		return max(client, server)
+	}
+	return min(client, server)
+}
diff --git a/vendor/github.com/rabbitmq/amqp091-go/spec091.go b/vendor/github.com/rabbitmq/amqp091-go/spec091.go
new file mode 100644
index 0000000000..d86e753a95
--- /dev/null
+++ b/vendor/github.com/rabbitmq/amqp091-go/spec091.go
@@ -0,0 +1,3382 @@
+// Copyright (c) 2021 VMware, Inc. or its affiliates. All Rights Reserved.
+// Copyright (c) 2012-2021, Sean Treadway, SoundCloud Ltd.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+ +/* GENERATED FILE - DO NOT EDIT */ +/* Rebuild from the spec/gen.go tool */ + +package amqp091 + +import ( + "encoding/binary" + "fmt" + "io" +) + +// Error codes that can be sent from the server during a connection or +// channel exception or used by the client to indicate a class of error like +// ErrCredentials. The text of the error is likely more interesting than +// these constants. +const ( + frameMethod = 1 + frameHeader = 2 + frameBody = 3 + frameHeartbeat = 8 + frameMinSize = 4096 + frameEnd = 206 + replySuccess = 200 + ContentTooLarge = 311 + NoRoute = 312 + NoConsumers = 313 + ConnectionForced = 320 + InvalidPath = 402 + AccessRefused = 403 + NotFound = 404 + ResourceLocked = 405 + PreconditionFailed = 406 + FrameError = 501 + SyntaxError = 502 + CommandInvalid = 503 + ChannelError = 504 + UnexpectedFrame = 505 + ResourceError = 506 + NotAllowed = 530 + NotImplemented = 540 + InternalError = 541 +) + +func isSoftExceptionCode(code int) bool { + switch code { + case 311: + return true + case 312: + return true + case 313: + return true + case 403: + return true + case 404: + return true + case 405: + return true + case 406: + return true + + } + return false +} + +type connectionStart struct { + VersionMajor byte + VersionMinor byte + ServerProperties Table + Mechanisms string + Locales string +} + +func (msg *connectionStart) id() (uint16, uint16) { + return 10, 10 +} + +func (msg *connectionStart) wait() bool { + return true +} + +func (msg *connectionStart) write(w io.Writer) (err error) { + + if err = binary.Write(w, binary.BigEndian, msg.VersionMajor); err != nil { + return + } + if err = binary.Write(w, binary.BigEndian, msg.VersionMinor); err != nil { + return + } + + if err = writeTable(w, msg.ServerProperties); err != nil { + return + } + + if err = writeLongstr(w, msg.Mechanisms); err != nil { + return + } + if err = writeLongstr(w, msg.Locales); err != nil { + return + } + + return +} + +func (msg *connectionStart) read(r io.Reader) (err error) { + + if err = binary.Read(r, binary.BigEndian, &msg.VersionMajor); err != nil { + return + } + if err = binary.Read(r, binary.BigEndian, &msg.VersionMinor); err != nil { + return + } + + if msg.ServerProperties, err = readTable(r); err != nil { + return + } + + if msg.Mechanisms, err = readLongstr(r); err != nil { + return + } + if msg.Locales, err = readLongstr(r); err != nil { + return + } + + return +} + +type connectionStartOk struct { + ClientProperties Table + Mechanism string + Response string + Locale string +} + +func (msg *connectionStartOk) id() (uint16, uint16) { + return 10, 11 +} + +func (msg *connectionStartOk) wait() bool { + return true +} + +func (msg *connectionStartOk) write(w io.Writer) (err error) { + + if err = writeTable(w, msg.ClientProperties); err != nil { + return + } + + if err = writeShortstr(w, msg.Mechanism); err != nil { + return + } + + if err = writeLongstr(w, msg.Response); err != nil { + return + } + + if err = writeShortstr(w, msg.Locale); err != nil { + return + } + + return +} + +func (msg *connectionStartOk) read(r io.Reader) (err error) { + + if msg.ClientProperties, err = readTable(r); err != nil { + return + } + + if msg.Mechanism, err = readShortstr(r); err != nil { + return + } + + if msg.Response, err = readLongstr(r); err != nil { + return + } + + if msg.Locale, err = readShortstr(r); err != nil { + return + } + + return +} + +type connectionSecure struct { + Challenge string +} + +func (msg *connectionSecure) id() (uint16, uint16) { + return 10, 20 +} + +func (msg 
*connectionSecure) wait() bool { + return true +} + +func (msg *connectionSecure) write(w io.Writer) (err error) { + + if err = writeLongstr(w, msg.Challenge); err != nil { + return + } + + return +} + +func (msg *connectionSecure) read(r io.Reader) (err error) { + + if msg.Challenge, err = readLongstr(r); err != nil { + return + } + + return +} + +type connectionSecureOk struct { + Response string +} + +func (msg *connectionSecureOk) id() (uint16, uint16) { + return 10, 21 +} + +func (msg *connectionSecureOk) wait() bool { + return true +} + +func (msg *connectionSecureOk) write(w io.Writer) (err error) { + + if err = writeLongstr(w, msg.Response); err != nil { + return + } + + return +} + +func (msg *connectionSecureOk) read(r io.Reader) (err error) { + + if msg.Response, err = readLongstr(r); err != nil { + return + } + + return +} + +type connectionTune struct { + ChannelMax uint16 + FrameMax uint32 + Heartbeat uint16 +} + +func (msg *connectionTune) id() (uint16, uint16) { + return 10, 30 +} + +func (msg *connectionTune) wait() bool { + return true +} + +func (msg *connectionTune) write(w io.Writer) (err error) { + + if err = binary.Write(w, binary.BigEndian, msg.ChannelMax); err != nil { + return + } + + if err = binary.Write(w, binary.BigEndian, msg.FrameMax); err != nil { + return + } + + if err = binary.Write(w, binary.BigEndian, msg.Heartbeat); err != nil { + return + } + + return +} + +func (msg *connectionTune) read(r io.Reader) (err error) { + + if err = binary.Read(r, binary.BigEndian, &msg.ChannelMax); err != nil { + return + } + + if err = binary.Read(r, binary.BigEndian, &msg.FrameMax); err != nil { + return + } + + if err = binary.Read(r, binary.BigEndian, &msg.Heartbeat); err != nil { + return + } + + return +} + +type connectionTuneOk struct { + ChannelMax uint16 + FrameMax uint32 + Heartbeat uint16 +} + +func (msg *connectionTuneOk) id() (uint16, uint16) { + return 10, 31 +} + +func (msg *connectionTuneOk) wait() bool { + return true +} + +func (msg *connectionTuneOk) write(w io.Writer) (err error) { + + if err = binary.Write(w, binary.BigEndian, msg.ChannelMax); err != nil { + return + } + + if err = binary.Write(w, binary.BigEndian, msg.FrameMax); err != nil { + return + } + + if err = binary.Write(w, binary.BigEndian, msg.Heartbeat); err != nil { + return + } + + return +} + +func (msg *connectionTuneOk) read(r io.Reader) (err error) { + + if err = binary.Read(r, binary.BigEndian, &msg.ChannelMax); err != nil { + return + } + + if err = binary.Read(r, binary.BigEndian, &msg.FrameMax); err != nil { + return + } + + if err = binary.Read(r, binary.BigEndian, &msg.Heartbeat); err != nil { + return + } + + return +} + +type connectionOpen struct { + VirtualHost string + reserved1 string + reserved2 bool +} + +func (msg *connectionOpen) id() (uint16, uint16) { + return 10, 40 +} + +func (msg *connectionOpen) wait() bool { + return true +} + +func (msg *connectionOpen) write(w io.Writer) (err error) { + var bits byte + + if err = writeShortstr(w, msg.VirtualHost); err != nil { + return + } + if err = writeShortstr(w, msg.reserved1); err != nil { + return + } + + if msg.reserved2 { + bits |= 1 << 0 + } + + if err = binary.Write(w, binary.BigEndian, bits); err != nil { + return + } + + return +} + +func (msg *connectionOpen) read(r io.Reader) (err error) { + var bits byte + + if msg.VirtualHost, err = readShortstr(r); err != nil { + return + } + if msg.reserved1, err = readShortstr(r); err != nil { + return + } + + if err = binary.Read(r, binary.BigEndian, &bits); err != nil 
{ + return + } + msg.reserved2 = (bits&(1<<0) > 0) + + return +} + +type connectionOpenOk struct { + reserved1 string +} + +func (msg *connectionOpenOk) id() (uint16, uint16) { + return 10, 41 +} + +func (msg *connectionOpenOk) wait() bool { + return true +} + +func (msg *connectionOpenOk) write(w io.Writer) (err error) { + + if err = writeShortstr(w, msg.reserved1); err != nil { + return + } + + return +} + +func (msg *connectionOpenOk) read(r io.Reader) (err error) { + + if msg.reserved1, err = readShortstr(r); err != nil { + return + } + + return +} + +type connectionClose struct { + ReplyCode uint16 + ReplyText string + ClassId uint16 + MethodId uint16 +} + +func (msg *connectionClose) id() (uint16, uint16) { + return 10, 50 +} + +func (msg *connectionClose) wait() bool { + return true +} + +func (msg *connectionClose) write(w io.Writer) (err error) { + + if err = binary.Write(w, binary.BigEndian, msg.ReplyCode); err != nil { + return + } + + if err = writeShortstr(w, msg.ReplyText); err != nil { + return + } + + if err = binary.Write(w, binary.BigEndian, msg.ClassId); err != nil { + return + } + if err = binary.Write(w, binary.BigEndian, msg.MethodId); err != nil { + return + } + + return +} + +func (msg *connectionClose) read(r io.Reader) (err error) { + + if err = binary.Read(r, binary.BigEndian, &msg.ReplyCode); err != nil { + return + } + + if msg.ReplyText, err = readShortstr(r); err != nil { + return + } + + if err = binary.Read(r, binary.BigEndian, &msg.ClassId); err != nil { + return + } + if err = binary.Read(r, binary.BigEndian, &msg.MethodId); err != nil { + return + } + + return +} + +type connectionCloseOk struct { +} + +func (msg *connectionCloseOk) id() (uint16, uint16) { + return 10, 51 +} + +func (msg *connectionCloseOk) wait() bool { + return true +} + +func (msg *connectionCloseOk) write(w io.Writer) (err error) { + + return +} + +func (msg *connectionCloseOk) read(r io.Reader) (err error) { + + return +} + +type connectionBlocked struct { + Reason string +} + +func (msg *connectionBlocked) id() (uint16, uint16) { + return 10, 60 +} + +func (msg *connectionBlocked) wait() bool { + return false +} + +func (msg *connectionBlocked) write(w io.Writer) (err error) { + + if err = writeShortstr(w, msg.Reason); err != nil { + return + } + + return +} + +func (msg *connectionBlocked) read(r io.Reader) (err error) { + + if msg.Reason, err = readShortstr(r); err != nil { + return + } + + return +} + +type connectionUnblocked struct { +} + +func (msg *connectionUnblocked) id() (uint16, uint16) { + return 10, 61 +} + +func (msg *connectionUnblocked) wait() bool { + return false +} + +func (msg *connectionUnblocked) write(w io.Writer) (err error) { + + return +} + +func (msg *connectionUnblocked) read(r io.Reader) (err error) { + + return +} + +type connectionUpdateSecret struct { + NewSecret string + Reason string +} + +func (msg *connectionUpdateSecret) id() (uint16, uint16) { + return 10, 70 +} + +func (msg *connectionUpdateSecret) wait() bool { + return true +} + +func (msg *connectionUpdateSecret) write(w io.Writer) (err error) { + + if err = writeLongstr(w, msg.NewSecret); err != nil { + return + } + + if err = writeShortstr(w, msg.Reason); err != nil { + return + } + + return +} + +func (msg *connectionUpdateSecret) read(r io.Reader) (err error) { + + if msg.NewSecret, err = readLongstr(r); err != nil { + return + } + + if msg.Reason, err = readShortstr(r); err != nil { + return + } + + return +} + +type connectionUpdateSecretOk struct { +} + +func (msg 
*connectionUpdateSecretOk) id() (uint16, uint16) { + return 10, 71 +} + +func (msg *connectionUpdateSecretOk) wait() bool { + return true +} + +func (msg *connectionUpdateSecretOk) write(w io.Writer) (err error) { + + return +} + +func (msg *connectionUpdateSecretOk) read(r io.Reader) (err error) { + + return +} + +type channelOpen struct { + reserved1 string +} + +func (msg *channelOpen) id() (uint16, uint16) { + return 20, 10 +} + +func (msg *channelOpen) wait() bool { + return true +} + +func (msg *channelOpen) write(w io.Writer) (err error) { + + if err = writeShortstr(w, msg.reserved1); err != nil { + return + } + + return +} + +func (msg *channelOpen) read(r io.Reader) (err error) { + + if msg.reserved1, err = readShortstr(r); err != nil { + return + } + + return +} + +type channelOpenOk struct { + reserved1 string +} + +func (msg *channelOpenOk) id() (uint16, uint16) { + return 20, 11 +} + +func (msg *channelOpenOk) wait() bool { + return true +} + +func (msg *channelOpenOk) write(w io.Writer) (err error) { + + if err = writeLongstr(w, msg.reserved1); err != nil { + return + } + + return +} + +func (msg *channelOpenOk) read(r io.Reader) (err error) { + + if msg.reserved1, err = readLongstr(r); err != nil { + return + } + + return +} + +type channelFlow struct { + Active bool +} + +func (msg *channelFlow) id() (uint16, uint16) { + return 20, 20 +} + +func (msg *channelFlow) wait() bool { + return true +} + +func (msg *channelFlow) write(w io.Writer) (err error) { + var bits byte + + if msg.Active { + bits |= 1 << 0 + } + + if err = binary.Write(w, binary.BigEndian, bits); err != nil { + return + } + + return +} + +func (msg *channelFlow) read(r io.Reader) (err error) { + var bits byte + + if err = binary.Read(r, binary.BigEndian, &bits); err != nil { + return + } + msg.Active = (bits&(1<<0) > 0) + + return +} + +type channelFlowOk struct { + Active bool +} + +func (msg *channelFlowOk) id() (uint16, uint16) { + return 20, 21 +} + +func (msg *channelFlowOk) wait() bool { + return false +} + +func (msg *channelFlowOk) write(w io.Writer) (err error) { + var bits byte + + if msg.Active { + bits |= 1 << 0 + } + + if err = binary.Write(w, binary.BigEndian, bits); err != nil { + return + } + + return +} + +func (msg *channelFlowOk) read(r io.Reader) (err error) { + var bits byte + + if err = binary.Read(r, binary.BigEndian, &bits); err != nil { + return + } + msg.Active = (bits&(1<<0) > 0) + + return +} + +type channelClose struct { + ReplyCode uint16 + ReplyText string + ClassId uint16 + MethodId uint16 +} + +func (msg *channelClose) id() (uint16, uint16) { + return 20, 40 +} + +func (msg *channelClose) wait() bool { + return true +} + +func (msg *channelClose) write(w io.Writer) (err error) { + + if err = binary.Write(w, binary.BigEndian, msg.ReplyCode); err != nil { + return + } + + if err = writeShortstr(w, msg.ReplyText); err != nil { + return + } + + if err = binary.Write(w, binary.BigEndian, msg.ClassId); err != nil { + return + } + if err = binary.Write(w, binary.BigEndian, msg.MethodId); err != nil { + return + } + + return +} + +func (msg *channelClose) read(r io.Reader) (err error) { + + if err = binary.Read(r, binary.BigEndian, &msg.ReplyCode); err != nil { + return + } + + if msg.ReplyText, err = readShortstr(r); err != nil { + return + } + + if err = binary.Read(r, binary.BigEndian, &msg.ClassId); err != nil { + return + } + if err = binary.Read(r, binary.BigEndian, &msg.MethodId); err != nil { + return + } + + return +} + +type channelCloseOk struct { +} + +func (msg 
*channelCloseOk) id() (uint16, uint16) { + return 20, 41 +} + +func (msg *channelCloseOk) wait() bool { + return true +} + +func (msg *channelCloseOk) write(w io.Writer) (err error) { + + return +} + +func (msg *channelCloseOk) read(r io.Reader) (err error) { + + return +} + +type exchangeDeclare struct { + reserved1 uint16 + Exchange string + Type string + Passive bool + Durable bool + AutoDelete bool + Internal bool + NoWait bool + Arguments Table +} + +func (msg *exchangeDeclare) id() (uint16, uint16) { + return 40, 10 +} + +func (msg *exchangeDeclare) wait() bool { + return true && !msg.NoWait +} + +func (msg *exchangeDeclare) write(w io.Writer) (err error) { + var bits byte + + if err = binary.Write(w, binary.BigEndian, msg.reserved1); err != nil { + return + } + + if err = writeShortstr(w, msg.Exchange); err != nil { + return + } + if err = writeShortstr(w, msg.Type); err != nil { + return + } + + if msg.Passive { + bits |= 1 << 0 + } + + if msg.Durable { + bits |= 1 << 1 + } + + if msg.AutoDelete { + bits |= 1 << 2 + } + + if msg.Internal { + bits |= 1 << 3 + } + + if msg.NoWait { + bits |= 1 << 4 + } + + if err = binary.Write(w, binary.BigEndian, bits); err != nil { + return + } + + if err = writeTable(w, msg.Arguments); err != nil { + return + } + + return +} + +func (msg *exchangeDeclare) read(r io.Reader) (err error) { + var bits byte + + if err = binary.Read(r, binary.BigEndian, &msg.reserved1); err != nil { + return + } + + if msg.Exchange, err = readShortstr(r); err != nil { + return + } + if msg.Type, err = readShortstr(r); err != nil { + return + } + + if err = binary.Read(r, binary.BigEndian, &bits); err != nil { + return + } + msg.Passive = (bits&(1<<0) > 0) + msg.Durable = (bits&(1<<1) > 0) + msg.AutoDelete = (bits&(1<<2) > 0) + msg.Internal = (bits&(1<<3) > 0) + msg.NoWait = (bits&(1<<4) > 0) + + if msg.Arguments, err = readTable(r); err != nil { + return + } + + return +} + +type exchangeDeclareOk struct { +} + +func (msg *exchangeDeclareOk) id() (uint16, uint16) { + return 40, 11 +} + +func (msg *exchangeDeclareOk) wait() bool { + return true +} + +func (msg *exchangeDeclareOk) write(w io.Writer) (err error) { + + return +} + +func (msg *exchangeDeclareOk) read(r io.Reader) (err error) { + + return +} + +type exchangeDelete struct { + reserved1 uint16 + Exchange string + IfUnused bool + NoWait bool +} + +func (msg *exchangeDelete) id() (uint16, uint16) { + return 40, 20 +} + +func (msg *exchangeDelete) wait() bool { + return true && !msg.NoWait +} + +func (msg *exchangeDelete) write(w io.Writer) (err error) { + var bits byte + + if err = binary.Write(w, binary.BigEndian, msg.reserved1); err != nil { + return + } + + if err = writeShortstr(w, msg.Exchange); err != nil { + return + } + + if msg.IfUnused { + bits |= 1 << 0 + } + + if msg.NoWait { + bits |= 1 << 1 + } + + if err = binary.Write(w, binary.BigEndian, bits); err != nil { + return + } + + return +} + +func (msg *exchangeDelete) read(r io.Reader) (err error) { + var bits byte + + if err = binary.Read(r, binary.BigEndian, &msg.reserved1); err != nil { + return + } + + if msg.Exchange, err = readShortstr(r); err != nil { + return + } + + if err = binary.Read(r, binary.BigEndian, &bits); err != nil { + return + } + msg.IfUnused = (bits&(1<<0) > 0) + msg.NoWait = (bits&(1<<1) > 0) + + return +} + +type exchangeDeleteOk struct { +} + +func (msg *exchangeDeleteOk) id() (uint16, uint16) { + return 40, 21 +} + +func (msg *exchangeDeleteOk) wait() bool { + return true +} + +func (msg *exchangeDeleteOk) write(w 
io.Writer) (err error) { + + return +} + +func (msg *exchangeDeleteOk) read(r io.Reader) (err error) { + + return +} + +type exchangeBind struct { + reserved1 uint16 + Destination string + Source string + RoutingKey string + NoWait bool + Arguments Table +} + +func (msg *exchangeBind) id() (uint16, uint16) { + return 40, 30 +} + +func (msg *exchangeBind) wait() bool { + return true && !msg.NoWait +} + +func (msg *exchangeBind) write(w io.Writer) (err error) { + var bits byte + + if err = binary.Write(w, binary.BigEndian, msg.reserved1); err != nil { + return + } + + if err = writeShortstr(w, msg.Destination); err != nil { + return + } + if err = writeShortstr(w, msg.Source); err != nil { + return + } + if err = writeShortstr(w, msg.RoutingKey); err != nil { + return + } + + if msg.NoWait { + bits |= 1 << 0 + } + + if err = binary.Write(w, binary.BigEndian, bits); err != nil { + return + } + + if err = writeTable(w, msg.Arguments); err != nil { + return + } + + return +} + +func (msg *exchangeBind) read(r io.Reader) (err error) { + var bits byte + + if err = binary.Read(r, binary.BigEndian, &msg.reserved1); err != nil { + return + } + + if msg.Destination, err = readShortstr(r); err != nil { + return + } + if msg.Source, err = readShortstr(r); err != nil { + return + } + if msg.RoutingKey, err = readShortstr(r); err != nil { + return + } + + if err = binary.Read(r, binary.BigEndian, &bits); err != nil { + return + } + msg.NoWait = (bits&(1<<0) > 0) + + if msg.Arguments, err = readTable(r); err != nil { + return + } + + return +} + +type exchangeBindOk struct { +} + +func (msg *exchangeBindOk) id() (uint16, uint16) { + return 40, 31 +} + +func (msg *exchangeBindOk) wait() bool { + return true +} + +func (msg *exchangeBindOk) write(w io.Writer) (err error) { + + return +} + +func (msg *exchangeBindOk) read(r io.Reader) (err error) { + + return +} + +type exchangeUnbind struct { + reserved1 uint16 + Destination string + Source string + RoutingKey string + NoWait bool + Arguments Table +} + +func (msg *exchangeUnbind) id() (uint16, uint16) { + return 40, 40 +} + +func (msg *exchangeUnbind) wait() bool { + return true && !msg.NoWait +} + +func (msg *exchangeUnbind) write(w io.Writer) (err error) { + var bits byte + + if err = binary.Write(w, binary.BigEndian, msg.reserved1); err != nil { + return + } + + if err = writeShortstr(w, msg.Destination); err != nil { + return + } + if err = writeShortstr(w, msg.Source); err != nil { + return + } + if err = writeShortstr(w, msg.RoutingKey); err != nil { + return + } + + if msg.NoWait { + bits |= 1 << 0 + } + + if err = binary.Write(w, binary.BigEndian, bits); err != nil { + return + } + + if err = writeTable(w, msg.Arguments); err != nil { + return + } + + return +} + +func (msg *exchangeUnbind) read(r io.Reader) (err error) { + var bits byte + + if err = binary.Read(r, binary.BigEndian, &msg.reserved1); err != nil { + return + } + + if msg.Destination, err = readShortstr(r); err != nil { + return + } + if msg.Source, err = readShortstr(r); err != nil { + return + } + if msg.RoutingKey, err = readShortstr(r); err != nil { + return + } + + if err = binary.Read(r, binary.BigEndian, &bits); err != nil { + return + } + msg.NoWait = (bits&(1<<0) > 0) + + if msg.Arguments, err = readTable(r); err != nil { + return + } + + return +} + +type exchangeUnbindOk struct { +} + +func (msg *exchangeUnbindOk) id() (uint16, uint16) { + return 40, 51 +} + +func (msg *exchangeUnbindOk) wait() bool { + return true +} + +func (msg *exchangeUnbindOk) write(w io.Writer) (err 
error) { + + return +} + +func (msg *exchangeUnbindOk) read(r io.Reader) (err error) { + + return +} + +type queueDeclare struct { + reserved1 uint16 + Queue string + Passive bool + Durable bool + Exclusive bool + AutoDelete bool + NoWait bool + Arguments Table +} + +func (msg *queueDeclare) id() (uint16, uint16) { + return 50, 10 +} + +func (msg *queueDeclare) wait() bool { + return true && !msg.NoWait +} + +func (msg *queueDeclare) write(w io.Writer) (err error) { + var bits byte + + if err = binary.Write(w, binary.BigEndian, msg.reserved1); err != nil { + return + } + + if err = writeShortstr(w, msg.Queue); err != nil { + return + } + + if msg.Passive { + bits |= 1 << 0 + } + + if msg.Durable { + bits |= 1 << 1 + } + + if msg.Exclusive { + bits |= 1 << 2 + } + + if msg.AutoDelete { + bits |= 1 << 3 + } + + if msg.NoWait { + bits |= 1 << 4 + } + + if err = binary.Write(w, binary.BigEndian, bits); err != nil { + return + } + + if err = writeTable(w, msg.Arguments); err != nil { + return + } + + return +} + +func (msg *queueDeclare) read(r io.Reader) (err error) { + var bits byte + + if err = binary.Read(r, binary.BigEndian, &msg.reserved1); err != nil { + return + } + + if msg.Queue, err = readShortstr(r); err != nil { + return + } + + if err = binary.Read(r, binary.BigEndian, &bits); err != nil { + return + } + msg.Passive = (bits&(1<<0) > 0) + msg.Durable = (bits&(1<<1) > 0) + msg.Exclusive = (bits&(1<<2) > 0) + msg.AutoDelete = (bits&(1<<3) > 0) + msg.NoWait = (bits&(1<<4) > 0) + + if msg.Arguments, err = readTable(r); err != nil { + return + } + + return +} + +type queueDeclareOk struct { + Queue string + MessageCount uint32 + ConsumerCount uint32 +} + +func (msg *queueDeclareOk) id() (uint16, uint16) { + return 50, 11 +} + +func (msg *queueDeclareOk) wait() bool { + return true +} + +func (msg *queueDeclareOk) write(w io.Writer) (err error) { + + if err = writeShortstr(w, msg.Queue); err != nil { + return + } + + if err = binary.Write(w, binary.BigEndian, msg.MessageCount); err != nil { + return + } + if err = binary.Write(w, binary.BigEndian, msg.ConsumerCount); err != nil { + return + } + + return +} + +func (msg *queueDeclareOk) read(r io.Reader) (err error) { + + if msg.Queue, err = readShortstr(r); err != nil { + return + } + + if err = binary.Read(r, binary.BigEndian, &msg.MessageCount); err != nil { + return + } + if err = binary.Read(r, binary.BigEndian, &msg.ConsumerCount); err != nil { + return + } + + return +} + +type queueBind struct { + reserved1 uint16 + Queue string + Exchange string + RoutingKey string + NoWait bool + Arguments Table +} + +func (msg *queueBind) id() (uint16, uint16) { + return 50, 20 +} + +func (msg *queueBind) wait() bool { + return true && !msg.NoWait +} + +func (msg *queueBind) write(w io.Writer) (err error) { + var bits byte + + if err = binary.Write(w, binary.BigEndian, msg.reserved1); err != nil { + return + } + + if err = writeShortstr(w, msg.Queue); err != nil { + return + } + if err = writeShortstr(w, msg.Exchange); err != nil { + return + } + if err = writeShortstr(w, msg.RoutingKey); err != nil { + return + } + + if msg.NoWait { + bits |= 1 << 0 + } + + if err = binary.Write(w, binary.BigEndian, bits); err != nil { + return + } + + if err = writeTable(w, msg.Arguments); err != nil { + return + } + + return +} + +func (msg *queueBind) read(r io.Reader) (err error) { + var bits byte + + if err = binary.Read(r, binary.BigEndian, &msg.reserved1); err != nil { + return + } + + if msg.Queue, err = readShortstr(r); err != nil { + return + } + if 
msg.Exchange, err = readShortstr(r); err != nil { + return + } + if msg.RoutingKey, err = readShortstr(r); err != nil { + return + } + + if err = binary.Read(r, binary.BigEndian, &bits); err != nil { + return + } + msg.NoWait = (bits&(1<<0) > 0) + + if msg.Arguments, err = readTable(r); err != nil { + return + } + + return +} + +type queueBindOk struct { +} + +func (msg *queueBindOk) id() (uint16, uint16) { + return 50, 21 +} + +func (msg *queueBindOk) wait() bool { + return true +} + +func (msg *queueBindOk) write(w io.Writer) (err error) { + + return +} + +func (msg *queueBindOk) read(r io.Reader) (err error) { + + return +} + +type queueUnbind struct { + reserved1 uint16 + Queue string + Exchange string + RoutingKey string + Arguments Table +} + +func (msg *queueUnbind) id() (uint16, uint16) { + return 50, 50 +} + +func (msg *queueUnbind) wait() bool { + return true +} + +func (msg *queueUnbind) write(w io.Writer) (err error) { + + if err = binary.Write(w, binary.BigEndian, msg.reserved1); err != nil { + return + } + + if err = writeShortstr(w, msg.Queue); err != nil { + return + } + if err = writeShortstr(w, msg.Exchange); err != nil { + return + } + if err = writeShortstr(w, msg.RoutingKey); err != nil { + return + } + + if err = writeTable(w, msg.Arguments); err != nil { + return + } + + return +} + +func (msg *queueUnbind) read(r io.Reader) (err error) { + + if err = binary.Read(r, binary.BigEndian, &msg.reserved1); err != nil { + return + } + + if msg.Queue, err = readShortstr(r); err != nil { + return + } + if msg.Exchange, err = readShortstr(r); err != nil { + return + } + if msg.RoutingKey, err = readShortstr(r); err != nil { + return + } + + if msg.Arguments, err = readTable(r); err != nil { + return + } + + return +} + +type queueUnbindOk struct { +} + +func (msg *queueUnbindOk) id() (uint16, uint16) { + return 50, 51 +} + +func (msg *queueUnbindOk) wait() bool { + return true +} + +func (msg *queueUnbindOk) write(w io.Writer) (err error) { + + return +} + +func (msg *queueUnbindOk) read(r io.Reader) (err error) { + + return +} + +type queuePurge struct { + reserved1 uint16 + Queue string + NoWait bool +} + +func (msg *queuePurge) id() (uint16, uint16) { + return 50, 30 +} + +func (msg *queuePurge) wait() bool { + return true && !msg.NoWait +} + +func (msg *queuePurge) write(w io.Writer) (err error) { + var bits byte + + if err = binary.Write(w, binary.BigEndian, msg.reserved1); err != nil { + return + } + + if err = writeShortstr(w, msg.Queue); err != nil { + return + } + + if msg.NoWait { + bits |= 1 << 0 + } + + if err = binary.Write(w, binary.BigEndian, bits); err != nil { + return + } + + return +} + +func (msg *queuePurge) read(r io.Reader) (err error) { + var bits byte + + if err = binary.Read(r, binary.BigEndian, &msg.reserved1); err != nil { + return + } + + if msg.Queue, err = readShortstr(r); err != nil { + return + } + + if err = binary.Read(r, binary.BigEndian, &bits); err != nil { + return + } + msg.NoWait = (bits&(1<<0) > 0) + + return +} + +type queuePurgeOk struct { + MessageCount uint32 +} + +func (msg *queuePurgeOk) id() (uint16, uint16) { + return 50, 31 +} + +func (msg *queuePurgeOk) wait() bool { + return true +} + +func (msg *queuePurgeOk) write(w io.Writer) (err error) { + + if err = binary.Write(w, binary.BigEndian, msg.MessageCount); err != nil { + return + } + + return +} + +func (msg *queuePurgeOk) read(r io.Reader) (err error) { + + if err = binary.Read(r, binary.BigEndian, &msg.MessageCount); err != nil { + return + } + + return +} + +type 
queueDelete struct { + reserved1 uint16 + Queue string + IfUnused bool + IfEmpty bool + NoWait bool +} + +func (msg *queueDelete) id() (uint16, uint16) { + return 50, 40 +} + +func (msg *queueDelete) wait() bool { + return true && !msg.NoWait +} + +func (msg *queueDelete) write(w io.Writer) (err error) { + var bits byte + + if err = binary.Write(w, binary.BigEndian, msg.reserved1); err != nil { + return + } + + if err = writeShortstr(w, msg.Queue); err != nil { + return + } + + if msg.IfUnused { + bits |= 1 << 0 + } + + if msg.IfEmpty { + bits |= 1 << 1 + } + + if msg.NoWait { + bits |= 1 << 2 + } + + if err = binary.Write(w, binary.BigEndian, bits); err != nil { + return + } + + return +} + +func (msg *queueDelete) read(r io.Reader) (err error) { + var bits byte + + if err = binary.Read(r, binary.BigEndian, &msg.reserved1); err != nil { + return + } + + if msg.Queue, err = readShortstr(r); err != nil { + return + } + + if err = binary.Read(r, binary.BigEndian, &bits); err != nil { + return + } + msg.IfUnused = (bits&(1<<0) > 0) + msg.IfEmpty = (bits&(1<<1) > 0) + msg.NoWait = (bits&(1<<2) > 0) + + return +} + +type queueDeleteOk struct { + MessageCount uint32 +} + +func (msg *queueDeleteOk) id() (uint16, uint16) { + return 50, 41 +} + +func (msg *queueDeleteOk) wait() bool { + return true +} + +func (msg *queueDeleteOk) write(w io.Writer) (err error) { + + if err = binary.Write(w, binary.BigEndian, msg.MessageCount); err != nil { + return + } + + return +} + +func (msg *queueDeleteOk) read(r io.Reader) (err error) { + + if err = binary.Read(r, binary.BigEndian, &msg.MessageCount); err != nil { + return + } + + return +} + +type basicQos struct { + PrefetchSize uint32 + PrefetchCount uint16 + Global bool +} + +func (msg *basicQos) id() (uint16, uint16) { + return 60, 10 +} + +func (msg *basicQos) wait() bool { + return true +} + +func (msg *basicQos) write(w io.Writer) (err error) { + var bits byte + + if err = binary.Write(w, binary.BigEndian, msg.PrefetchSize); err != nil { + return + } + + if err = binary.Write(w, binary.BigEndian, msg.PrefetchCount); err != nil { + return + } + + if msg.Global { + bits |= 1 << 0 + } + + if err = binary.Write(w, binary.BigEndian, bits); err != nil { + return + } + + return +} + +func (msg *basicQos) read(r io.Reader) (err error) { + var bits byte + + if err = binary.Read(r, binary.BigEndian, &msg.PrefetchSize); err != nil { + return + } + + if err = binary.Read(r, binary.BigEndian, &msg.PrefetchCount); err != nil { + return + } + + if err = binary.Read(r, binary.BigEndian, &bits); err != nil { + return + } + msg.Global = (bits&(1<<0) > 0) + + return +} + +type basicQosOk struct { +} + +func (msg *basicQosOk) id() (uint16, uint16) { + return 60, 11 +} + +func (msg *basicQosOk) wait() bool { + return true +} + +func (msg *basicQosOk) write(w io.Writer) (err error) { + + return +} + +func (msg *basicQosOk) read(r io.Reader) (err error) { + + return +} + +type basicConsume struct { + reserved1 uint16 + Queue string + ConsumerTag string + NoLocal bool + NoAck bool + Exclusive bool + NoWait bool + Arguments Table +} + +func (msg *basicConsume) id() (uint16, uint16) { + return 60, 20 +} + +func (msg *basicConsume) wait() bool { + return true && !msg.NoWait +} + +func (msg *basicConsume) write(w io.Writer) (err error) { + var bits byte + + if err = binary.Write(w, binary.BigEndian, msg.reserved1); err != nil { + return + } + + if err = writeShortstr(w, msg.Queue); err != nil { + return + } + if err = writeShortstr(w, msg.ConsumerTag); err != nil { + return + } 
+ + if msg.NoLocal { + bits |= 1 << 0 + } + + if msg.NoAck { + bits |= 1 << 1 + } + + if msg.Exclusive { + bits |= 1 << 2 + } + + if msg.NoWait { + bits |= 1 << 3 + } + + if err = binary.Write(w, binary.BigEndian, bits); err != nil { + return + } + + if err = writeTable(w, msg.Arguments); err != nil { + return + } + + return +} + +func (msg *basicConsume) read(r io.Reader) (err error) { + var bits byte + + if err = binary.Read(r, binary.BigEndian, &msg.reserved1); err != nil { + return + } + + if msg.Queue, err = readShortstr(r); err != nil { + return + } + if msg.ConsumerTag, err = readShortstr(r); err != nil { + return + } + + if err = binary.Read(r, binary.BigEndian, &bits); err != nil { + return + } + msg.NoLocal = (bits&(1<<0) > 0) + msg.NoAck = (bits&(1<<1) > 0) + msg.Exclusive = (bits&(1<<2) > 0) + msg.NoWait = (bits&(1<<3) > 0) + + if msg.Arguments, err = readTable(r); err != nil { + return + } + + return +} + +type basicConsumeOk struct { + ConsumerTag string +} + +func (msg *basicConsumeOk) id() (uint16, uint16) { + return 60, 21 +} + +func (msg *basicConsumeOk) wait() bool { + return true +} + +func (msg *basicConsumeOk) write(w io.Writer) (err error) { + + if err = writeShortstr(w, msg.ConsumerTag); err != nil { + return + } + + return +} + +func (msg *basicConsumeOk) read(r io.Reader) (err error) { + + if msg.ConsumerTag, err = readShortstr(r); err != nil { + return + } + + return +} + +type basicCancel struct { + ConsumerTag string + NoWait bool +} + +func (msg *basicCancel) id() (uint16, uint16) { + return 60, 30 +} + +func (msg *basicCancel) wait() bool { + return true && !msg.NoWait +} + +func (msg *basicCancel) write(w io.Writer) (err error) { + var bits byte + + if err = writeShortstr(w, msg.ConsumerTag); err != nil { + return + } + + if msg.NoWait { + bits |= 1 << 0 + } + + if err = binary.Write(w, binary.BigEndian, bits); err != nil { + return + } + + return +} + +func (msg *basicCancel) read(r io.Reader) (err error) { + var bits byte + + if msg.ConsumerTag, err = readShortstr(r); err != nil { + return + } + + if err = binary.Read(r, binary.BigEndian, &bits); err != nil { + return + } + msg.NoWait = (bits&(1<<0) > 0) + + return +} + +type basicCancelOk struct { + ConsumerTag string +} + +func (msg *basicCancelOk) id() (uint16, uint16) { + return 60, 31 +} + +func (msg *basicCancelOk) wait() bool { + return true +} + +func (msg *basicCancelOk) write(w io.Writer) (err error) { + + if err = writeShortstr(w, msg.ConsumerTag); err != nil { + return + } + + return +} + +func (msg *basicCancelOk) read(r io.Reader) (err error) { + + if msg.ConsumerTag, err = readShortstr(r); err != nil { + return + } + + return +} + +type basicPublish struct { + reserved1 uint16 + Exchange string + RoutingKey string + Mandatory bool + Immediate bool + Properties properties + Body []byte +} + +func (msg *basicPublish) id() (uint16, uint16) { + return 60, 40 +} + +func (msg *basicPublish) wait() bool { + return false +} + +func (msg *basicPublish) getContent() (properties, []byte) { + return msg.Properties, msg.Body +} + +func (msg *basicPublish) setContent(props properties, body []byte) { + msg.Properties, msg.Body = props, body +} + +func (msg *basicPublish) write(w io.Writer) (err error) { + var bits byte + + if err = binary.Write(w, binary.BigEndian, msg.reserved1); err != nil { + return + } + + if err = writeShortstr(w, msg.Exchange); err != nil { + return + } + if err = writeShortstr(w, msg.RoutingKey); err != nil { + return + } + + if msg.Mandatory { + bits |= 1 << 0 + } + + if 
msg.Immediate { + bits |= 1 << 1 + } + + if err = binary.Write(w, binary.BigEndian, bits); err != nil { + return + } + + return +} + +func (msg *basicPublish) read(r io.Reader) (err error) { + var bits byte + + if err = binary.Read(r, binary.BigEndian, &msg.reserved1); err != nil { + return + } + + if msg.Exchange, err = readShortstr(r); err != nil { + return + } + if msg.RoutingKey, err = readShortstr(r); err != nil { + return + } + + if err = binary.Read(r, binary.BigEndian, &bits); err != nil { + return + } + msg.Mandatory = (bits&(1<<0) > 0) + msg.Immediate = (bits&(1<<1) > 0) + + return +} + +type basicReturn struct { + ReplyCode uint16 + ReplyText string + Exchange string + RoutingKey string + Properties properties + Body []byte +} + +func (msg *basicReturn) id() (uint16, uint16) { + return 60, 50 +} + +func (msg *basicReturn) wait() bool { + return false +} + +func (msg *basicReturn) getContent() (properties, []byte) { + return msg.Properties, msg.Body +} + +func (msg *basicReturn) setContent(props properties, body []byte) { + msg.Properties, msg.Body = props, body +} + +func (msg *basicReturn) write(w io.Writer) (err error) { + + if err = binary.Write(w, binary.BigEndian, msg.ReplyCode); err != nil { + return + } + + if err = writeShortstr(w, msg.ReplyText); err != nil { + return + } + if err = writeShortstr(w, msg.Exchange); err != nil { + return + } + if err = writeShortstr(w, msg.RoutingKey); err != nil { + return + } + + return +} + +func (msg *basicReturn) read(r io.Reader) (err error) { + + if err = binary.Read(r, binary.BigEndian, &msg.ReplyCode); err != nil { + return + } + + if msg.ReplyText, err = readShortstr(r); err != nil { + return + } + if msg.Exchange, err = readShortstr(r); err != nil { + return + } + if msg.RoutingKey, err = readShortstr(r); err != nil { + return + } + + return +} + +type basicDeliver struct { + ConsumerTag string + DeliveryTag uint64 + Redelivered bool + Exchange string + RoutingKey string + Properties properties + Body []byte +} + +func (msg *basicDeliver) id() (uint16, uint16) { + return 60, 60 +} + +func (msg *basicDeliver) wait() bool { + return false +} + +func (msg *basicDeliver) getContent() (properties, []byte) { + return msg.Properties, msg.Body +} + +func (msg *basicDeliver) setContent(props properties, body []byte) { + msg.Properties, msg.Body = props, body +} + +func (msg *basicDeliver) write(w io.Writer) (err error) { + var bits byte + + if err = writeShortstr(w, msg.ConsumerTag); err != nil { + return + } + + if err = binary.Write(w, binary.BigEndian, msg.DeliveryTag); err != nil { + return + } + + if msg.Redelivered { + bits |= 1 << 0 + } + + if err = binary.Write(w, binary.BigEndian, bits); err != nil { + return + } + + if err = writeShortstr(w, msg.Exchange); err != nil { + return + } + if err = writeShortstr(w, msg.RoutingKey); err != nil { + return + } + + return +} + +func (msg *basicDeliver) read(r io.Reader) (err error) { + var bits byte + + if msg.ConsumerTag, err = readShortstr(r); err != nil { + return + } + + if err = binary.Read(r, binary.BigEndian, &msg.DeliveryTag); err != nil { + return + } + + if err = binary.Read(r, binary.BigEndian, &bits); err != nil { + return + } + msg.Redelivered = (bits&(1<<0) > 0) + + if msg.Exchange, err = readShortstr(r); err != nil { + return + } + if msg.RoutingKey, err = readShortstr(r); err != nil { + return + } + + return +} + +type basicGet struct { + reserved1 uint16 + Queue string + NoAck bool +} + +func (msg *basicGet) id() (uint16, uint16) { + return 60, 70 +} + +func (msg 
*basicGet) wait() bool { + return true +} + +func (msg *basicGet) write(w io.Writer) (err error) { + var bits byte + + if err = binary.Write(w, binary.BigEndian, msg.reserved1); err != nil { + return + } + + if err = writeShortstr(w, msg.Queue); err != nil { + return + } + + if msg.NoAck { + bits |= 1 << 0 + } + + if err = binary.Write(w, binary.BigEndian, bits); err != nil { + return + } + + return +} + +func (msg *basicGet) read(r io.Reader) (err error) { + var bits byte + + if err = binary.Read(r, binary.BigEndian, &msg.reserved1); err != nil { + return + } + + if msg.Queue, err = readShortstr(r); err != nil { + return + } + + if err = binary.Read(r, binary.BigEndian, &bits); err != nil { + return + } + msg.NoAck = (bits&(1<<0) > 0) + + return +} + +type basicGetOk struct { + DeliveryTag uint64 + Redelivered bool + Exchange string + RoutingKey string + MessageCount uint32 + Properties properties + Body []byte +} + +func (msg *basicGetOk) id() (uint16, uint16) { + return 60, 71 +} + +func (msg *basicGetOk) wait() bool { + return true +} + +func (msg *basicGetOk) getContent() (properties, []byte) { + return msg.Properties, msg.Body +} + +func (msg *basicGetOk) setContent(props properties, body []byte) { + msg.Properties, msg.Body = props, body +} + +func (msg *basicGetOk) write(w io.Writer) (err error) { + var bits byte + + if err = binary.Write(w, binary.BigEndian, msg.DeliveryTag); err != nil { + return + } + + if msg.Redelivered { + bits |= 1 << 0 + } + + if err = binary.Write(w, binary.BigEndian, bits); err != nil { + return + } + + if err = writeShortstr(w, msg.Exchange); err != nil { + return + } + if err = writeShortstr(w, msg.RoutingKey); err != nil { + return + } + + if err = binary.Write(w, binary.BigEndian, msg.MessageCount); err != nil { + return + } + + return +} + +func (msg *basicGetOk) read(r io.Reader) (err error) { + var bits byte + + if err = binary.Read(r, binary.BigEndian, &msg.DeliveryTag); err != nil { + return + } + + if err = binary.Read(r, binary.BigEndian, &bits); err != nil { + return + } + msg.Redelivered = (bits&(1<<0) > 0) + + if msg.Exchange, err = readShortstr(r); err != nil { + return + } + if msg.RoutingKey, err = readShortstr(r); err != nil { + return + } + + if err = binary.Read(r, binary.BigEndian, &msg.MessageCount); err != nil { + return + } + + return +} + +type basicGetEmpty struct { + reserved1 string +} + +func (msg *basicGetEmpty) id() (uint16, uint16) { + return 60, 72 +} + +func (msg *basicGetEmpty) wait() bool { + return true +} + +func (msg *basicGetEmpty) write(w io.Writer) (err error) { + + if err = writeShortstr(w, msg.reserved1); err != nil { + return + } + + return +} + +func (msg *basicGetEmpty) read(r io.Reader) (err error) { + + if msg.reserved1, err = readShortstr(r); err != nil { + return + } + + return +} + +type basicAck struct { + DeliveryTag uint64 + Multiple bool +} + +func (msg *basicAck) id() (uint16, uint16) { + return 60, 80 +} + +func (msg *basicAck) wait() bool { + return false +} + +func (msg *basicAck) write(w io.Writer) (err error) { + var bits byte + + if err = binary.Write(w, binary.BigEndian, msg.DeliveryTag); err != nil { + return + } + + if msg.Multiple { + bits |= 1 << 0 + } + + if err = binary.Write(w, binary.BigEndian, bits); err != nil { + return + } + + return +} + +func (msg *basicAck) read(r io.Reader) (err error) { + var bits byte + + if err = binary.Read(r, binary.BigEndian, &msg.DeliveryTag); err != nil { + return + } + + if err = binary.Read(r, binary.BigEndian, &bits); err != nil { + return + } + 
msg.Multiple = (bits&(1<<0) > 0) + + return +} + +type basicReject struct { + DeliveryTag uint64 + Requeue bool +} + +func (msg *basicReject) id() (uint16, uint16) { + return 60, 90 +} + +func (msg *basicReject) wait() bool { + return false +} + +func (msg *basicReject) write(w io.Writer) (err error) { + var bits byte + + if err = binary.Write(w, binary.BigEndian, msg.DeliveryTag); err != nil { + return + } + + if msg.Requeue { + bits |= 1 << 0 + } + + if err = binary.Write(w, binary.BigEndian, bits); err != nil { + return + } + + return +} + +func (msg *basicReject) read(r io.Reader) (err error) { + var bits byte + + if err = binary.Read(r, binary.BigEndian, &msg.DeliveryTag); err != nil { + return + } + + if err = binary.Read(r, binary.BigEndian, &bits); err != nil { + return + } + msg.Requeue = (bits&(1<<0) > 0) + + return +} + +type basicRecoverAsync struct { + Requeue bool +} + +func (msg *basicRecoverAsync) id() (uint16, uint16) { + return 60, 100 +} + +func (msg *basicRecoverAsync) wait() bool { + return false +} + +func (msg *basicRecoverAsync) write(w io.Writer) (err error) { + var bits byte + + if msg.Requeue { + bits |= 1 << 0 + } + + if err = binary.Write(w, binary.BigEndian, bits); err != nil { + return + } + + return +} + +func (msg *basicRecoverAsync) read(r io.Reader) (err error) { + var bits byte + + if err = binary.Read(r, binary.BigEndian, &bits); err != nil { + return + } + msg.Requeue = (bits&(1<<0) > 0) + + return +} + +type basicRecover struct { + Requeue bool +} + +func (msg *basicRecover) id() (uint16, uint16) { + return 60, 110 +} + +func (msg *basicRecover) wait() bool { + return true +} + +func (msg *basicRecover) write(w io.Writer) (err error) { + var bits byte + + if msg.Requeue { + bits |= 1 << 0 + } + + if err = binary.Write(w, binary.BigEndian, bits); err != nil { + return + } + + return +} + +func (msg *basicRecover) read(r io.Reader) (err error) { + var bits byte + + if err = binary.Read(r, binary.BigEndian, &bits); err != nil { + return + } + msg.Requeue = (bits&(1<<0) > 0) + + return +} + +type basicRecoverOk struct { +} + +func (msg *basicRecoverOk) id() (uint16, uint16) { + return 60, 111 +} + +func (msg *basicRecoverOk) wait() bool { + return true +} + +func (msg *basicRecoverOk) write(w io.Writer) (err error) { + + return +} + +func (msg *basicRecoverOk) read(r io.Reader) (err error) { + + return +} + +type basicNack struct { + DeliveryTag uint64 + Multiple bool + Requeue bool +} + +func (msg *basicNack) id() (uint16, uint16) { + return 60, 120 +} + +func (msg *basicNack) wait() bool { + return false +} + +func (msg *basicNack) write(w io.Writer) (err error) { + var bits byte + + if err = binary.Write(w, binary.BigEndian, msg.DeliveryTag); err != nil { + return + } + + if msg.Multiple { + bits |= 1 << 0 + } + + if msg.Requeue { + bits |= 1 << 1 + } + + if err = binary.Write(w, binary.BigEndian, bits); err != nil { + return + } + + return +} + +func (msg *basicNack) read(r io.Reader) (err error) { + var bits byte + + if err = binary.Read(r, binary.BigEndian, &msg.DeliveryTag); err != nil { + return + } + + if err = binary.Read(r, binary.BigEndian, &bits); err != nil { + return + } + msg.Multiple = (bits&(1<<0) > 0) + msg.Requeue = (bits&(1<<1) > 0) + + return +} + +type txSelect struct { +} + +func (msg *txSelect) id() (uint16, uint16) { + return 90, 10 +} + +func (msg *txSelect) wait() bool { + return true +} + +func (msg *txSelect) write(w io.Writer) (err error) { + + return +} + +func (msg *txSelect) read(r io.Reader) (err error) { + + return +} + 
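All of the method types above share one wire convention worth calling out: every boolean field is packed into a single octet, LSB first, on write, and recovered with the mirrored mask test on read. Below is a minimal, self-contained sketch of that round trip; the nackFlags type is a toy stand-in for fields like basicNack.Multiple and basicNack.Requeue, not part of the library.

```go
package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"io"
)

// nackFlags is a hypothetical stand-in for the boolean fields of basicNack.
type nackFlags struct {
	Multiple bool
	Requeue  bool
}

// write packs both booleans into one octet, LSB first, as the methods above do.
func (f nackFlags) write(w io.Writer) error {
	var bits byte
	if f.Multiple {
		bits |= 1 << 0
	}
	if f.Requeue {
		bits |= 1 << 1
	}
	return binary.Write(w, binary.BigEndian, bits)
}

// read recovers the booleans with the mirrored mask tests.
func (f *nackFlags) read(r io.Reader) error {
	var bits byte
	if err := binary.Read(r, binary.BigEndian, &bits); err != nil {
		return err
	}
	f.Multiple = bits&(1<<0) > 0
	f.Requeue = bits&(1<<1) > 0
	return nil
}

func main() {
	var buf bytes.Buffer
	in := nackFlags{Multiple: true, Requeue: true}
	if err := in.write(&buf); err != nil {
		panic(err)
	}
	fmt.Printf("octet: %08b\n", buf.Bytes()[0]) // octet: 00000011

	var out nackFlags
	if err := out.read(&buf); err != nil {
		panic(err)
	}
	fmt.Printf("decoded: %+v\n", out) // decoded: {Multiple:true Requeue:true}
}
```

Because encode and decode use the same bit positions, adding a flag only requires extending both masks in lockstep.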
+type txSelectOk struct { +} + +func (msg *txSelectOk) id() (uint16, uint16) { + return 90, 11 +} + +func (msg *txSelectOk) wait() bool { + return true +} + +func (msg *txSelectOk) write(w io.Writer) (err error) { + + return +} + +func (msg *txSelectOk) read(r io.Reader) (err error) { + + return +} + +type txCommit struct { +} + +func (msg *txCommit) id() (uint16, uint16) { + return 90, 20 +} + +func (msg *txCommit) wait() bool { + return true +} + +func (msg *txCommit) write(w io.Writer) (err error) { + + return +} + +func (msg *txCommit) read(r io.Reader) (err error) { + + return +} + +type txCommitOk struct { +} + +func (msg *txCommitOk) id() (uint16, uint16) { + return 90, 21 +} + +func (msg *txCommitOk) wait() bool { + return true +} + +func (msg *txCommitOk) write(w io.Writer) (err error) { + + return +} + +func (msg *txCommitOk) read(r io.Reader) (err error) { + + return +} + +type txRollback struct { +} + +func (msg *txRollback) id() (uint16, uint16) { + return 90, 30 +} + +func (msg *txRollback) wait() bool { + return true +} + +func (msg *txRollback) write(w io.Writer) (err error) { + + return +} + +func (msg *txRollback) read(r io.Reader) (err error) { + + return +} + +type txRollbackOk struct { +} + +func (msg *txRollbackOk) id() (uint16, uint16) { + return 90, 31 +} + +func (msg *txRollbackOk) wait() bool { + return true +} + +func (msg *txRollbackOk) write(w io.Writer) (err error) { + + return +} + +func (msg *txRollbackOk) read(r io.Reader) (err error) { + + return +} + +type confirmSelect struct { + Nowait bool +} + +func (msg *confirmSelect) id() (uint16, uint16) { + return 85, 10 +} + +func (msg *confirmSelect) wait() bool { + return true +} + +func (msg *confirmSelect) write(w io.Writer) (err error) { + var bits byte + + if msg.Nowait { + bits |= 1 << 0 + } + + if err = binary.Write(w, binary.BigEndian, bits); err != nil { + return + } + + return +} + +func (msg *confirmSelect) read(r io.Reader) (err error) { + var bits byte + + if err = binary.Read(r, binary.BigEndian, &bits); err != nil { + return + } + msg.Nowait = (bits&(1<<0) > 0) + + return +} + +type confirmSelectOk struct { +} + +func (msg *confirmSelectOk) id() (uint16, uint16) { + return 85, 11 +} + +func (msg *confirmSelectOk) wait() bool { + return true +} + +func (msg *confirmSelectOk) write(w io.Writer) (err error) { + + return +} + +func (msg *confirmSelectOk) read(r io.Reader) (err error) { + + return +} + +func (r *reader) parseMethodFrame(channel uint16, size uint32) (f frame, err error) { + mf := &methodFrame{ + ChannelId: channel, + } + + if err = binary.Read(r.r, binary.BigEndian, &mf.ClassId); err != nil { + return + } + + if err = binary.Read(r.r, binary.BigEndian, &mf.MethodId); err != nil { + return + } + + switch mf.ClassId { + + case 10: // connection + switch mf.MethodId { + + case 10: // connection start + //fmt.Println("NextMethod: class:10 method:10") + method := &connectionStart{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 11: // connection start-ok + //fmt.Println("NextMethod: class:10 method:11") + method := &connectionStartOk{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 20: // connection secure + //fmt.Println("NextMethod: class:10 method:20") + method := &connectionSecure{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 21: // connection secure-ok + //fmt.Println("NextMethod: class:10 method:21") + method := &connectionSecureOk{} + if err = method.read(r.r); err != 
nil { + return + } + mf.Method = method + + case 30: // connection tune + //fmt.Println("NextMethod: class:10 method:30") + method := &connectionTune{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 31: // connection tune-ok + //fmt.Println("NextMethod: class:10 method:31") + method := &connectionTuneOk{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 40: // connection open + //fmt.Println("NextMethod: class:10 method:40") + method := &connectionOpen{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 41: // connection open-ok + //fmt.Println("NextMethod: class:10 method:41") + method := &connectionOpenOk{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 50: // connection close + //fmt.Println("NextMethod: class:10 method:50") + method := &connectionClose{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 51: // connection close-ok + //fmt.Println("NextMethod: class:10 method:51") + method := &connectionCloseOk{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 60: // connection blocked + //fmt.Println("NextMethod: class:10 method:60") + method := &connectionBlocked{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 61: // connection unblocked + //fmt.Println("NextMethod: class:10 method:61") + method := &connectionUnblocked{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 70: // connection update-secret + //fmt.Println("NextMethod: class:10 method:70") + method := &connectionUpdateSecret{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 71: // connection update-secret-ok + //fmt.Println("NextMethod: class:10 method:71") + method := &connectionUpdateSecretOk{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + default: + return nil, fmt.Errorf("Bad method frame, unknown method %d for class %d", mf.MethodId, mf.ClassId) + } + + case 20: // channel + switch mf.MethodId { + + case 10: // channel open + //fmt.Println("NextMethod: class:20 method:10") + method := &channelOpen{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 11: // channel open-ok + //fmt.Println("NextMethod: class:20 method:11") + method := &channelOpenOk{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 20: // channel flow + //fmt.Println("NextMethod: class:20 method:20") + method := &channelFlow{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 21: // channel flow-ok + //fmt.Println("NextMethod: class:20 method:21") + method := &channelFlowOk{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 40: // channel close + //fmt.Println("NextMethod: class:20 method:40") + method := &channelClose{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 41: // channel close-ok + //fmt.Println("NextMethod: class:20 method:41") + method := &channelCloseOk{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + default: + return nil, fmt.Errorf("Bad method frame, unknown method %d for class %d", mf.MethodId, mf.ClassId) + } + + case 40: // exchange + switch mf.MethodId { + + case 10: // exchange declare + //fmt.Println("NextMethod: class:40 method:10") + method := &exchangeDeclare{} + 
if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 11: // exchange declare-ok + //fmt.Println("NextMethod: class:40 method:11") + method := &exchangeDeclareOk{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 20: // exchange delete + //fmt.Println("NextMethod: class:40 method:20") + method := &exchangeDelete{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 21: // exchange delete-ok + //fmt.Println("NextMethod: class:40 method:21") + method := &exchangeDeleteOk{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 30: // exchange bind + //fmt.Println("NextMethod: class:40 method:30") + method := &exchangeBind{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 31: // exchange bind-ok + //fmt.Println("NextMethod: class:40 method:31") + method := &exchangeBindOk{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 40: // exchange unbind + //fmt.Println("NextMethod: class:40 method:40") + method := &exchangeUnbind{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 51: // exchange unbind-ok + //fmt.Println("NextMethod: class:40 method:51") + method := &exchangeUnbindOk{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + default: + return nil, fmt.Errorf("Bad method frame, unknown method %d for class %d", mf.MethodId, mf.ClassId) + } + + case 50: // queue + switch mf.MethodId { + + case 10: // queue declare + //fmt.Println("NextMethod: class:50 method:10") + method := &queueDeclare{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 11: // queue declare-ok + //fmt.Println("NextMethod: class:50 method:11") + method := &queueDeclareOk{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 20: // queue bind + //fmt.Println("NextMethod: class:50 method:20") + method := &queueBind{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 21: // queue bind-ok + //fmt.Println("NextMethod: class:50 method:21") + method := &queueBindOk{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 50: // queue unbind + //fmt.Println("NextMethod: class:50 method:50") + method := &queueUnbind{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 51: // queue unbind-ok + //fmt.Println("NextMethod: class:50 method:51") + method := &queueUnbindOk{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 30: // queue purge + //fmt.Println("NextMethod: class:50 method:30") + method := &queuePurge{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 31: // queue purge-ok + //fmt.Println("NextMethod: class:50 method:31") + method := &queuePurgeOk{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 40: // queue delete + //fmt.Println("NextMethod: class:50 method:40") + method := &queueDelete{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 41: // queue delete-ok + //fmt.Println("NextMethod: class:50 method:41") + method := &queueDeleteOk{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + default: + return nil, fmt.Errorf("Bad method frame, unknown method %d for class %d", mf.MethodId, mf.ClassId) + } + + case 60: // basic + switch 
mf.MethodId { + + case 10: // basic qos + //fmt.Println("NextMethod: class:60 method:10") + method := &basicQos{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 11: // basic qos-ok + //fmt.Println("NextMethod: class:60 method:11") + method := &basicQosOk{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 20: // basic consume + //fmt.Println("NextMethod: class:60 method:20") + method := &basicConsume{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 21: // basic consume-ok + //fmt.Println("NextMethod: class:60 method:21") + method := &basicConsumeOk{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 30: // basic cancel + //fmt.Println("NextMethod: class:60 method:30") + method := &basicCancel{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 31: // basic cancel-ok + //fmt.Println("NextMethod: class:60 method:31") + method := &basicCancelOk{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 40: // basic publish + //fmt.Println("NextMethod: class:60 method:40") + method := &basicPublish{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 50: // basic return + //fmt.Println("NextMethod: class:60 method:50") + method := &basicReturn{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 60: // basic deliver + //fmt.Println("NextMethod: class:60 method:60") + method := &basicDeliver{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 70: // basic get + //fmt.Println("NextMethod: class:60 method:70") + method := &basicGet{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 71: // basic get-ok + //fmt.Println("NextMethod: class:60 method:71") + method := &basicGetOk{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 72: // basic get-empty + //fmt.Println("NextMethod: class:60 method:72") + method := &basicGetEmpty{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 80: // basic ack + //fmt.Println("NextMethod: class:60 method:80") + method := &basicAck{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 90: // basic reject + //fmt.Println("NextMethod: class:60 method:90") + method := &basicReject{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 100: // basic recover-async + //fmt.Println("NextMethod: class:60 method:100") + method := &basicRecoverAsync{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 110: // basic recover + //fmt.Println("NextMethod: class:60 method:110") + method := &basicRecover{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 111: // basic recover-ok + //fmt.Println("NextMethod: class:60 method:111") + method := &basicRecoverOk{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 120: // basic nack + //fmt.Println("NextMethod: class:60 method:120") + method := &basicNack{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + default: + return nil, fmt.Errorf("Bad method frame, unknown method %d for class %d", mf.MethodId, mf.ClassId) + } + + case 90: // tx + switch mf.MethodId { + + case 10: // tx select + //fmt.Println("NextMethod: class:90 
method:10") + method := &txSelect{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 11: // tx select-ok + //fmt.Println("NextMethod: class:90 method:11") + method := &txSelectOk{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 20: // tx commit + //fmt.Println("NextMethod: class:90 method:20") + method := &txCommit{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 21: // tx commit-ok + //fmt.Println("NextMethod: class:90 method:21") + method := &txCommitOk{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 30: // tx rollback + //fmt.Println("NextMethod: class:90 method:30") + method := &txRollback{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 31: // tx rollback-ok + //fmt.Println("NextMethod: class:90 method:31") + method := &txRollbackOk{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + default: + return nil, fmt.Errorf("Bad method frame, unknown method %d for class %d", mf.MethodId, mf.ClassId) + } + + case 85: // confirm + switch mf.MethodId { + + case 10: // confirm select + //fmt.Println("NextMethod: class:85 method:10") + method := &confirmSelect{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + case 11: // confirm select-ok + //fmt.Println("NextMethod: class:85 method:11") + method := &confirmSelectOk{} + if err = method.read(r.r); err != nil { + return + } + mf.Method = method + + default: + return nil, fmt.Errorf("Bad method frame, unknown method %d for class %d", mf.MethodId, mf.ClassId) + } + + default: + return nil, fmt.Errorf("Bad method frame, unknown class %d", mf.ClassId) + } + + return mf, nil +} diff --git a/vendor/github.com/rabbitmq/amqp091-go/types.go b/vendor/github.com/rabbitmq/amqp091-go/types.go new file mode 100644 index 0000000000..e8d8986a69 --- /dev/null +++ b/vendor/github.com/rabbitmq/amqp091-go/types.go @@ -0,0 +1,514 @@ +// Copyright (c) 2021 VMware, Inc. or its affiliates. All Rights Reserved. +// Copyright (c) 2012-2021, Sean Treadway, SoundCloud Ltd. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package amqp091 + +import ( + "fmt" + "io" + "time" +) + +const DefaultExchange = "" + +// Constants for standard AMQP 0-9-1 exchange types. +const ( + ExchangeDirect = "direct" + ExchangeFanout = "fanout" + ExchangeTopic = "topic" + ExchangeHeaders = "headers" +) + +var ( + // ErrClosed is returned when the channel or connection is not open + ErrClosed = &Error{Code: ChannelError, Reason: "channel/connection is not open"} + + // ErrChannelMax is returned when Connection.Channel has been called enough + // times that all channel IDs have been exhausted in the client or the + // server. + ErrChannelMax = &Error{Code: ChannelError, Reason: "channel id space exhausted"} + + // ErrSASL is returned from Dial when the authentication mechanism could not + // be negotiated. + ErrSASL = &Error{Code: AccessRefused, Reason: "SASL could not negotiate a shared mechanism"} + + // ErrCredentials is returned when the authenticated client is not authorized + // to any vhost. + ErrCredentials = &Error{Code: AccessRefused, Reason: "username or password not allowed"} + + // ErrVhost is returned when the authenticated user is not permitted to + // access the requested Vhost. 
+	ErrVhost = &Error{Code: AccessRefused, Reason: "no access to this vhost"} + + // ErrSyntax is a hard protocol error, indicating an unsupported protocol, +// implementation or encoding. + ErrSyntax = &Error{Code: SyntaxError, Reason: "invalid field or value inside of a frame"} + + // ErrFrame is returned when the protocol frame cannot be read from the +// server, indicating an unsupported protocol or unsupported frame type. + ErrFrame = &Error{Code: FrameError, Reason: "frame could not be parsed"} + + // ErrCommandInvalid is returned when the server sends an unexpected response +// to this requested message type. This indicates a bug in this client. + ErrCommandInvalid = &Error{Code: CommandInvalid, Reason: "unexpected command received"} + + // ErrUnexpectedFrame is returned when something other than a method or +// heartbeat frame is delivered to the Connection, indicating a bug in the +// client. + ErrUnexpectedFrame = &Error{Code: UnexpectedFrame, Reason: "unexpected frame received"} + + // ErrFieldType is returned when writing a message containing a Go type unsupported by AMQP. + ErrFieldType = &Error{Code: SyntaxError, Reason: "unsupported table field type"} +) + +// internal errors used inside the library +var ( + errInvalidTypeAssertion = &Error{Code: InternalError, Reason: "type assertion unsuccessful", Server: false, Recover: true} +) + +// Error captures the code and reason a channel or connection has been closed +// by the server. +type Error struct { + Code int // constant code from the specification + Reason string // description of the error + Server bool // true when initiated from the server, false when from this library + Recover bool // true when this error can be recovered by retrying later or with different parameters +} + +func newError(code uint16, text string) *Error { + return &Error{ + Code: int(code), + Reason: text, + Recover: isSoftExceptionCode(int(code)), + Server: true, + } +} + +func (e Error) Error() string { + return fmt.Sprintf("Exception (%d) Reason: %q", e.Code, e.Reason) +} + +// Used by header frames to capture routing and header information +type properties struct { + ContentType string // MIME content type + ContentEncoding string // MIME content encoding + Headers Table // Application or header exchange table + DeliveryMode uint8 // queue implementation use - Transient (1) or Persistent (2) + Priority uint8 // queue implementation use - 0 to 9 + CorrelationId string // application use - correlation identifier + ReplyTo string // application use - address to reply to (ex: RPC) + Expiration string // implementation use - message expiration spec + MessageId string // application use - message identifier + Timestamp time.Time // application use - message timestamp + Type string // application use - message type name + UserId string // application use - creating user id + AppId string // application use - creating application + reserved1 string // was cluster-id - process for buffer consumption +} + +// DeliveryMode. Transient means higher throughput but messages will not be +// restored on broker restart. The delivery mode of publishings is unrelated +// to the durability of the queues they reside on. Transient messages will +// not be restored to durable queues, persistent messages will be restored to +// durable queues and lost on non-durable queues during server restart. +// +// This remains typed as uint8 to match Publishing.DeliveryMode. Other +// delivery modes specific to custom queue implementations are not enumerated +// here. 
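To make the DeliveryMode note above concrete, here is a hedged sketch of publishing a persistent message through the public API. It assumes an already-open *amqp091.Channel named ch and an existing durable queue named "tasks", and it uses PublishWithContext, the non-deprecated publish call in recent releases of this library (older versions expose Channel.Publish with the same trailing parameters).

```go
package main

import (
	"context"

	amqp "github.com/rabbitmq/amqp091-go"
)

// publishPersistent sends a message that durable queues keep across a broker
// restart. The channel ch and the queue "tasks" are assumed to exist.
func publishPersistent(ctx context.Context, ch *amqp.Channel) error {
	return ch.PublishWithContext(ctx,
		amqp.DefaultExchange, // default exchange: routes by queue name
		"tasks",              // routing key
		false,                // mandatory
		false,                // immediate
		amqp.Publishing{
			ContentType:  "text/plain",
			DeliveryMode: amqp.Persistent,
			Body:         []byte("hello"),
		})
}
```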
+const ( + Transient uint8 = 1 + Persistent uint8 = 2 +) + +// The property flags are an array of bits that indicate the presence or +// absence of each property value in sequence. The bits are ordered from most +// significant to least significant - bit 15 indicates the first property. +const ( + flagContentType = 0x8000 + flagContentEncoding = 0x4000 + flagHeaders = 0x2000 + flagDeliveryMode = 0x1000 + flagPriority = 0x0800 + flagCorrelationId = 0x0400 + flagReplyTo = 0x0200 + flagExpiration = 0x0100 + flagMessageId = 0x0080 + flagTimestamp = 0x0040 + flagType = 0x0020 + flagUserId = 0x0010 + flagAppId = 0x0008 + flagReserved1 = 0x0004 +) + +// Queue captures the current server state of the queue on the server returned +// from Channel.QueueDeclare or Channel.QueueInspect. +type Queue struct { + Name string // server confirmed or generated name + Messages int // count of messages not awaiting acknowledgment + Consumers int // number of consumers receiving deliveries +} + +// Publishing captures the client message sent to the server. The fields +// outside of the Headers table included in this struct mirror the underlying +// fields in the content frame. They use native types for convenience and +// efficiency. +type Publishing struct { + // Application or exchange specific fields, + // the headers exchange will inspect this field. + Headers Table + + // Properties + ContentType string // MIME content type + ContentEncoding string // MIME content encoding + DeliveryMode uint8 // Transient (0 or 1) or Persistent (2) + Priority uint8 // 0 to 9 + CorrelationId string // correlation identifier + ReplyTo string // address to reply to (ex: RPC) + Expiration string // message expiration spec + MessageId string // message identifier + Timestamp time.Time // message timestamp + Type string // message type name + UserId string // creating user id - ex: "guest" + AppId string // creating application id + + // The application specific payload of the message + Body []byte +} + +// Blocking notifies the server's TCP flow control of the Connection. When a +// server hits a memory or disk alarm it will block all connections until the +// resources are reclaimed. Use NotifyBlock on the Connection to receive these +// events. +type Blocking struct { + Active bool // TCP pushback active/inactive on server + Reason string // Server reason for activation +} + +// DeferredConfirmation represents a future publisher confirm for a message. It +// allows users to directly correlate a publishing to a confirmation. These are +// returned from PublishWithDeferredConfirm on Channels. +type DeferredConfirmation struct { + DeliveryTag uint64 + + done chan struct{} + ack bool +} + +// Confirmation notifies the acknowledgment or negative acknowledgement of a +// publishing identified by its delivery tag. Use NotifyPublish on the Channel +// to consume these events. +type Confirmation struct { + DeliveryTag uint64 // A 1 based counter of publishings from when the channel was put in Confirm mode + Ack bool // True when the server successfully received the publishing +} + +// Decimal matches the AMQP decimal type. Scale is the number of decimal +// digits; for example, Scale == 2 and Value == 12345 represent the decimal 123.45. +type Decimal struct { + Scale uint8 + Value int32 +} + +// Most common queue argument keys in queue declaration. For a comprehensive list +// of queue arguments, visit [RabbitMQ Queue docs]. +// +// QueueTypeArg queue argument is used to declare quorum and stream queues. 
+// Accepted values are QueueTypeClassic (default), QueueTypeQuorum and +// QueueTypeStream. [Quorum Queues] accept (almost) all queue arguments as their +// Classic Queues counterparts. Check [feature comparison] docs for more +// information. +// +// Queues can define their [max length] using QueueMaxLenArg and +// QueueMaxLenBytesArg queue arguments. Overflow behaviour is set using +// QueueOverflowArg. Accepted values are QueueOverflowDropHead (default), +// QueueOverflowRejectPublish and QueueOverflowRejectPublishDLX. +// +// [Queue TTL] can be defined using QueueTTLArg. That is, the time-to-live for an +// unused queue. [Queue Message TTL] can be defined using QueueMessageTTLArg. +// This will set a time-to-live for **messages** in the queue. +// +// [Stream retention] can be configured using StreamMaxLenBytesArg, to set the +// maximum size of the stream. Please note that stream queues always keep, at +// least, one segment. [Stream retention] can also be set using StreamMaxAgeArg, +// to set time-based retention. Values are string with unit suffix. Valid +// suffixes are Y, M, D, h, m, s. E.g. "7D" for one week. The maximum segment +// size can be set using StreamMaxSegmentSizeBytesArg. The default value is +// 500_000_000 bytes ~= 500 megabytes +// +// [RabbitMQ Queue docs]: https://rabbitmq.com/queues.html +// [Stream retention]: https://rabbitmq.com/streams.html#retention +// [max length]: https://rabbitmq.com/maxlength.html +// [Queue TTL]: https://rabbitmq.com/ttl.html#queue-ttl +// [Queue Message TTL]: https://rabbitmq.com/ttl.html#per-queue-message-ttl +// [Quorum Queues]: https://rabbitmq.com/quorum-queues.html +// [feature comparison]: https://rabbitmq.com/quorum-queues.html#feature-comparison +const ( + QueueTypeArg = "x-queue-type" + QueueMaxLenArg = "x-max-length" + QueueMaxLenBytesArg = "x-max-length-bytes" + StreamMaxLenBytesArg = "x-max-length-bytes" + QueueOverflowArg = "x-overflow" + QueueMessageTTLArg = "x-message-ttl" + QueueTTLArg = "x-expires" + StreamMaxAgeArg = "x-max-age" + StreamMaxSegmentSizeBytesArg = "x-stream-max-segment-size-bytes" +) + +// Values for queue arguments. Use as values for queue arguments during queue declaration. +// The following argument table will create a classic queue, with max length set to 100 messages, +// and a queue TTL of 30 minutes. +// +// args := amqp.Table{ +// amqp.QueueTypeArg: QueueTypeClassic, +// amqp.QueueMaxLenArg: 100, +// amqp.QueueTTLArg: 1800000, +// } +const ( + QueueTypeClassic = "classic" + QueueTypeQuorum = "quorum" + QueueTypeStream = "stream" + QueueOverflowDropHead = "drop-head" + QueueOverflowRejectPublish = "reject-publish" + QueueOverflowRejectPublishDLX = "reject-publish-dlx" +) + +// Table stores user supplied fields of the following types: +// +// bool +// byte +// int8 +// float32 +// float64 +// int +// int16 +// int32 +// int64 +// nil +// string +// time.Time +// amqp.Decimal +// amqp.Table +// []byte +// []interface{} - containing above types +// +// Functions taking a table will immediately fail when the table contains a +// value of an unsupported type. +// +// The caller must be specific in which precision of integer it wishes to +// encode. +// +// Use a type assertion when reading values from a table for type conversion. +// +// RabbitMQ expects int32 for integer values. 
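Since the argument keys and values above are plain string constants, a declaration is just a Table literal. A sketch of declaring a quorum queue with the constants defined here; the channel ch and the queue name "orders" are assumptions, and QueueDeclare is the public API that accepts the args table:

```go
package main

import (
	amqp "github.com/rabbitmq/amqp091-go"
)

// declareQuorum declares a durable quorum queue capped at 100k messages with
// a 30-minute queue TTL.
func declareQuorum(ch *amqp.Channel) (amqp.Queue, error) {
	args := amqp.Table{
		amqp.QueueTypeArg:   amqp.QueueTypeQuorum,
		amqp.QueueMaxLenArg: int32(100_000),   // RabbitMQ expects int32 for integers
		amqp.QueueTTLArg:    int32(1_800_000), // milliseconds
	}
	if err := args.Validate(); err != nil { // fails fast on unsupported Go types
		return amqp.Queue{}, err
	}
	// name, durable, autoDelete, exclusive, noWait, args
	return ch.QueueDeclare("orders", true, false, false, false, args)
}
```

Note that quorum queues must be declared durable, hence durable=true.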
+type Table map[string]interface{} + +func validateField(f interface{}) error { + switch fv := f.(type) { + case nil, bool, byte, int8, int, int16, int32, int64, float32, float64, string, []byte, Decimal, time.Time: + return nil + + case []interface{}: + for _, v := range fv { + if err := validateField(v); err != nil { + return fmt.Errorf("in array %s", err) + } + } + return nil + + case Table: + for k, v := range fv { + if err := validateField(v); err != nil { + return fmt.Errorf("table field %q %s", k, err) + } + } + return nil + } + + return fmt.Errorf("value %T not supported", f) +} + +// Validate returns an error if any Go types in the table are incompatible with AMQP types. +func (t Table) Validate() error { + return validateField(t) +} + +// SetClientConnectionName sets the connection name property. This property can be used in +// amqp.Config to set a custom connection name during amqp.DialConfig(). This +// can be helpful to identify specific connections in RabbitMQ, for debugging or +// tracing purposes. +func (t Table) SetClientConnectionName(connName string) { + t["connection_name"] = connName +} + +type message interface { + id() (uint16, uint16) + wait() bool + read(io.Reader) error + write(io.Writer) error +} + +type messageWithContent interface { + message + getContent() (properties, []byte) + setContent(properties, []byte) +} + +/* +The base interface implemented as: + +2.3.5 frame Details + +All frames consist of a header (7 octets), a payload of arbitrary size, and a 'frame-end' octet that detects +malformed frames: + + 0 1 3 7 size+7 size+8 + +------+---------+-------------+ +------------+ +-----------+ + | type | channel | size | | payload | | frame-end | + +------+---------+-------------+ +------------+ +-----------+ + octet short long size octets octet + +To read a frame, we: + + 1. Read the header and check the frame type and channel. + 2. Depending on the frame type, we read the payload and process it. + 3. Read the frame end octet. + +In realistic implementations where performance is a concern, we would use +“read-ahead buffering” or “gathering reads” to avoid doing three separate +system calls to read a frame. +*/ +type frame interface { + write(io.Writer) error + channel() uint16 +} + +/* +Perform any updates on the channel immediately after the frame is decoded while the +connection mutex is held. +*/ +func updateChannel(f frame, channel *Channel) { + if mf, isMethodFrame := f.(*methodFrame); isMethodFrame { + if _, isChannelClose := mf.Method.(*channelClose); isChannelClose { + channel.setClosed() + } + } +} + +type reader struct { + r io.Reader +} + +type writer struct { + w io.Writer +} + +// Implements the frame interface for Connection RPC +type protocolHeader struct{} + +func (protocolHeader) write(w io.Writer) error { + _, err := w.Write([]byte{'A', 'M', 'Q', 'P', 0, 0, 9, 1}) + return err +} + +func (protocolHeader) channel() uint16 { + panic("only valid as initial handshake") +} + +/* +Method frames carry the high-level protocol commands (which we call "methods"). +One method frame carries one command. The method frame payload has this format: + + 0 2 4 + +----------+-----------+-------------- - - + | class-id | method-id | arguments... + +----------+-----------+-------------- - - + short short ... + +To process a method frame, we: + 1. Read the method frame payload. + 2. Unpack it into a structure. A given method always has the same structure, + so we can unpack the method rapidly. + 3. Check that the method is allowed in the current context. + 4. 
Check that the method arguments are valid. + 5. Execute the method. + +Method frame bodies are constructed as a list of AMQP data fields (bits, +integers, strings and string tables). The marshalling code is trivially +generated directly from the protocol specifications, and can be very rapid. +*/ +type methodFrame struct { + ChannelId uint16 + ClassId uint16 + MethodId uint16 + Method message +} + +func (f *methodFrame) channel() uint16 { return f.ChannelId } + +/* +Heartbeating is a technique designed to undo one of TCP/IP's features, namely +its ability to recover from a broken physical connection by closing only after +a quite long time-out. In some scenarios we need to know very rapidly if a +peer is disconnected or not responding for other reasons (e.g. it is looping). +Since heartbeating can be done at a low level, we implement this as a special +type of frame that peers exchange at the transport level, rather than as a +class method. +*/ +type heartbeatFrame struct { + ChannelId uint16 +} + +func (f *heartbeatFrame) channel() uint16 { return f.ChannelId } + +/* +Certain methods (such as Basic.Publish, Basic.Deliver, etc.) are formally +defined as carrying content. When a peer sends such a method frame, it always +follows it with a content header and zero or more content body frames. + +A content header frame has this format: + + 0 2 4 12 14 + +----------+--------+-----------+----------------+------------- - - + | class-id | weight | body size | property flags | property list... + +----------+--------+-----------+----------------+------------- - - + short short long long short remainder... + +We place content body in distinct frames (rather than including it in the +method) so that AMQP may support "zero copy" techniques in which content is +never marshalled or encoded. We place the content properties in their own +frame so that recipients can selectively discard contents they do not want to +process +*/ +type headerFrame struct { + ChannelId uint16 + ClassId uint16 + weight uint16 + Size uint64 + Properties properties +} + +func (f *headerFrame) channel() uint16 { return f.ChannelId } + +/* +Content is the application data we carry from client-to-client via the AMQP +server. Content is, roughly speaking, a set of properties plus a binary data +part. The set of allowed properties are defined by the Basic class, and these +form the "content header frame". The data can be any size, and MAY be broken +into several (or many) chunks, each forming a "content body frame". + +Looking at the frames for a specific channel, as they pass on the wire, we +might see something like this: + + [method] + [method] [header] [body] [body] + [method] + ... +*/ +type bodyFrame struct { + ChannelId uint16 + Body []byte +} + +func (f *bodyFrame) channel() uint16 { return f.ChannelId } diff --git a/vendor/github.com/rabbitmq/amqp091-go/write.go b/vendor/github.com/rabbitmq/amqp091-go/write.go new file mode 100644 index 0000000000..d0011f86c4 --- /dev/null +++ b/vendor/github.com/rabbitmq/amqp091-go/write.go @@ -0,0 +1,427 @@ +// Copyright (c) 2021 VMware, Inc. or its affiliates. All Rights Reserved. +// Copyright (c) 2012-2021, Sean Treadway, SoundCloud Ltd. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
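The writeFrame helper in write.go below implements the 7-octet header + payload + frame-end layout documented above. Before it, a self-contained sketch that hand-encodes a heartbeat frame on channel 0; the numeric constants here (heartbeat frame type 8, frame-end octet 0xCE) come from the AMQP 0-9-1 spec rather than from this package and are spelled out as assumptions:

```go
package main

import "fmt"

func main() {
	const (
		frameHeartbeat = 8    // frame type octet for heartbeats (AMQP 0-9-1 spec)
		frameEnd       = 0xCE // frame-end sentinel octet
	)
	channel := uint16(0) // heartbeats always travel on channel 0
	var payload []byte   // and carry no payload
	size := uint32(len(payload))

	frame := []byte{
		frameHeartbeat,
		byte(channel >> 8), byte(channel), // channel: big-endian short
		byte(size >> 24), byte(size >> 16), byte(size >> 8), byte(size), // size: big-endian long
	}
	frame = append(frame, payload...)
	frame = append(frame, frameEnd)

	fmt.Printf("%% x -> % x\n", frame) // 08 00 00 00 00 00 00 ce
}
```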
+ +package amqp091 + +import ( + "bufio" + "bytes" + "encoding/binary" + "errors" + "io" + "math" + "time" +) + +func (w *writer) WriteFrameNoFlush(frame frame) (err error) { + err = frame.write(w.w) + return +} + +func (w *writer) WriteFrame(frame frame) (err error) { + if err = frame.write(w.w); err != nil { + return + } + + if buf, ok := w.w.(*bufio.Writer); ok { + err = buf.Flush() + } + + return +} + +func (f *methodFrame) write(w io.Writer) (err error) { + var payload bytes.Buffer + + if f.Method == nil { + return errors.New("malformed frame: missing method") + } + + class, method := f.Method.id() + + if err = binary.Write(&payload, binary.BigEndian, class); err != nil { + return + } + + if err = binary.Write(&payload, binary.BigEndian, method); err != nil { + return + } + + if err = f.Method.write(&payload); err != nil { + return + } + + return writeFrame(w, frameMethod, f.ChannelId, payload.Bytes()) +} + +// Heartbeat +// +// Payload is empty +func (f *heartbeatFrame) write(w io.Writer) (err error) { + return writeFrame(w, frameHeartbeat, f.ChannelId, []byte{}) +} + +// CONTENT HEADER +// 0 2 4 12 14 +// +----------+--------+-----------+----------------+------------- - - +// | class-id | weight | body size | property flags | property list... +// +----------+--------+-----------+----------------+------------- - - +// +// short short long long short remainder... +func (f *headerFrame) write(w io.Writer) (err error) { + var payload bytes.Buffer + var zeroTime time.Time + + if err = binary.Write(&payload, binary.BigEndian, f.ClassId); err != nil { + return + } + + if err = binary.Write(&payload, binary.BigEndian, f.weight); err != nil { + return + } + + if err = binary.Write(&payload, binary.BigEndian, f.Size); err != nil { + return + } + + // First pass will build the mask to be serialized, second pass will serialize + // each of the fields that appear in the mask. 
+ + var mask uint16 + + if len(f.Properties.ContentType) > 0 { + mask = mask | flagContentType + } + if len(f.Properties.ContentEncoding) > 0 { + mask = mask | flagContentEncoding + } + if f.Properties.Headers != nil && len(f.Properties.Headers) > 0 { + mask = mask | flagHeaders + } + if f.Properties.DeliveryMode > 0 { + mask = mask | flagDeliveryMode + } + if f.Properties.Priority > 0 { + mask = mask | flagPriority + } + if len(f.Properties.CorrelationId) > 0 { + mask = mask | flagCorrelationId + } + if len(f.Properties.ReplyTo) > 0 { + mask = mask | flagReplyTo + } + if len(f.Properties.Expiration) > 0 { + mask = mask | flagExpiration + } + if len(f.Properties.MessageId) > 0 { + mask = mask | flagMessageId + } + if f.Properties.Timestamp != zeroTime { + mask = mask | flagTimestamp + } + if len(f.Properties.Type) > 0 { + mask = mask | flagType + } + if len(f.Properties.UserId) > 0 { + mask = mask | flagUserId + } + if len(f.Properties.AppId) > 0 { + mask = mask | flagAppId + } + + if err = binary.Write(&payload, binary.BigEndian, mask); err != nil { + return + } + + if hasProperty(mask, flagContentType) { + if err = writeShortstr(&payload, f.Properties.ContentType); err != nil { + return + } + } + if hasProperty(mask, flagContentEncoding) { + if err = writeShortstr(&payload, f.Properties.ContentEncoding); err != nil { + return + } + } + if hasProperty(mask, flagHeaders) { + if err = writeTable(&payload, f.Properties.Headers); err != nil { + return + } + } + if hasProperty(mask, flagDeliveryMode) { + if err = binary.Write(&payload, binary.BigEndian, f.Properties.DeliveryMode); err != nil { + return + } + } + if hasProperty(mask, flagPriority) { + if err = binary.Write(&payload, binary.BigEndian, f.Properties.Priority); err != nil { + return + } + } + if hasProperty(mask, flagCorrelationId) { + if err = writeShortstr(&payload, f.Properties.CorrelationId); err != nil { + return + } + } + if hasProperty(mask, flagReplyTo) { + if err = writeShortstr(&payload, f.Properties.ReplyTo); err != nil { + return + } + } + if hasProperty(mask, flagExpiration) { + if err = writeShortstr(&payload, f.Properties.Expiration); err != nil { + return + } + } + if hasProperty(mask, flagMessageId) { + if err = writeShortstr(&payload, f.Properties.MessageId); err != nil { + return + } + } + if hasProperty(mask, flagTimestamp) { + if err = binary.Write(&payload, binary.BigEndian, uint64(f.Properties.Timestamp.Unix())); err != nil { + return + } + } + if hasProperty(mask, flagType) { + if err = writeShortstr(&payload, f.Properties.Type); err != nil { + return + } + } + if hasProperty(mask, flagUserId) { + if err = writeShortstr(&payload, f.Properties.UserId); err != nil { + return + } + } + if hasProperty(mask, flagAppId) { + if err = writeShortstr(&payload, f.Properties.AppId); err != nil { + return + } + } + + return writeFrame(w, frameHeader, f.ChannelId, payload.Bytes()) +} + +// Body +// +// Payload is one byte range from the full body whose size is declared in the +// Header frame +func (f *bodyFrame) write(w io.Writer) (err error) { + return writeFrame(w, frameBody, f.ChannelId, f.Body) +} + +func writeFrame(w io.Writer, typ uint8, channel uint16, payload []byte) (err error) { + end := []byte{frameEnd} + size := uint(len(payload)) + + _, err = w.Write([]byte{ + typ, + byte((channel & 0xff00) >> 8), + byte((channel & 0x00ff) >> 0), + byte((size & 0xff000000) >> 24), + byte((size & 0x00ff0000) >> 16), + byte((size & 0x0000ff00) >> 8), + byte((size & 0x000000ff) >> 0), + }) + + if err != nil { + return + } + + if 
_, err = w.Write(payload); err != nil { + return + } + + if _, err = w.Write(end); err != nil { + return + } + + return +} + +func writeShortstr(w io.Writer, s string) (err error) { + b := []byte(s) + + var length = uint8(len(b)) + + if err = binary.Write(w, binary.BigEndian, length); err != nil { + return + } + + if _, err = w.Write(b[:length]); err != nil { + return + } + + return +} + +func writeLongstr(w io.Writer, s string) (err error) { + b := []byte(s) + + var length = uint32(len(b)) + + if err = binary.Write(w, binary.BigEndian, length); err != nil { + return + } + + if _, err = w.Write(b[:length]); err != nil { + return + } + + return +} + +/* +'A': []interface{} +'D': Decimal +'F': Table +'I': int32 +'S': string +'T': time.Time +'V': nil +'b': int8 +'B': byte +'d': float64 +'f': float32 +'l': int64 +'s': int16 +'t': bool +'x': []byte +*/ +func writeField(w io.Writer, value interface{}) (err error) { + var buf [9]byte + var enc []byte + + switch v := value.(type) { + case bool: + buf[0] = 't' + if v { + buf[1] = byte(1) + } else { + buf[1] = byte(0) + } + enc = buf[:2] + + case byte: + buf[0] = 'B' + buf[1] = v + enc = buf[:2] + + case int8: + buf[0] = 'b' + buf[1] = uint8(v) + enc = buf[:2] + + case int16: + buf[0] = 's' + binary.BigEndian.PutUint16(buf[1:3], uint16(v)) + enc = buf[:3] + + case int: + buf[0] = 'I' + binary.BigEndian.PutUint32(buf[1:5], uint32(v)) + enc = buf[:5] + + case int32: + buf[0] = 'I' + binary.BigEndian.PutUint32(buf[1:5], uint32(v)) + enc = buf[:5] + + case int64: + buf[0] = 'l' + binary.BigEndian.PutUint64(buf[1:9], uint64(v)) + enc = buf[:9] + + case float32: + buf[0] = 'f' + binary.BigEndian.PutUint32(buf[1:5], math.Float32bits(v)) + enc = buf[:5] + + case float64: + buf[0] = 'd' + binary.BigEndian.PutUint64(buf[1:9], math.Float64bits(v)) + enc = buf[:9] + + case Decimal: + buf[0] = 'D' + buf[1] = v.Scale + binary.BigEndian.PutUint32(buf[2:6], uint32(v.Value)) + enc = buf[:6] + + case string: + buf[0] = 'S' + binary.BigEndian.PutUint32(buf[1:5], uint32(len(v))) + enc = append(buf[:5], []byte(v)...) 
+ + case []interface{}: // field-array + buf[0] = 'A' + + sec := new(bytes.Buffer) + for _, val := range v { + if err = writeField(sec, val); err != nil { + return + } + } + + binary.BigEndian.PutUint32(buf[1:5], uint32(sec.Len())) + if _, err = w.Write(buf[:5]); err != nil { + return + } + + if _, err = w.Write(sec.Bytes()); err != nil { + return + } + + return + + case time.Time: + buf[0] = 'T' + binary.BigEndian.PutUint64(buf[1:9], uint64(v.Unix())) + enc = buf[:9] + + case Table: + if _, err = w.Write([]byte{'F'}); err != nil { + return + } + return writeTable(w, v) + + case []byte: + buf[0] = 'x' + binary.BigEndian.PutUint32(buf[1:5], uint32(len(v))) + if _, err = w.Write(buf[0:5]); err != nil { + return + } + if _, err = w.Write(v); err != nil { + return + } + return + + case nil: + buf[0] = 'V' + enc = buf[:1] + + default: + return ErrFieldType + } + + _, err = w.Write(enc) + + return +} + +func writeTable(w io.Writer, table Table) (err error) { + var buf bytes.Buffer + + for key, val := range table { + if err = writeShortstr(&buf, key); err != nil { + return + } + if err = writeField(&buf, val); err != nil { + return + } + } + + return writeLongstr(w, buf.String()) +} diff --git a/vendor/github.com/rubenv/sql-migrate/Dockerfile b/vendor/github.com/rubenv/sql-migrate/Dockerfile new file mode 100644 index 0000000000..cfa00f7eaf --- /dev/null +++ b/vendor/github.com/rubenv/sql-migrate/Dockerfile @@ -0,0 +1,25 @@ +ARG GO_VERSION=1.16.2 +ARG ALPINE_VERSION=3.12 + +### Vendor +FROM golang:${GO_VERSION} as vendor +COPY . /project +WORKDIR /project +RUN go mod tidy && go mod vendor + +### Build binary +FROM golang:${GO_VERSION} as build-binary +COPY . /project +COPY --from=vendor /project/vendor /project/vendor +WORKDIR /project +RUN GOOS=linux GOARCH=amd64 CGO_ENABLED=0 GO111MODULE=on go build \ + -v \ + -mod vendor \ + -o /project/bin/sql-migrate \ + /project/sql-migrate + +### Image +FROM alpine:${ALPINE_VERSION} as image +COPY --from=build-binary /project/bin/sql-migrate /usr/local/bin/sql-migrate +RUN chmod +x /usr/local/bin/sql-migrate +ENTRYPOINT ["sql-migrate"] diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/default_value_decoders.go b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/default_value_decoders.go new file mode 100644 index 0000000000..b5e22c498a --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/default_value_decoders.go @@ -0,0 +1,1807 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. 
You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package bsoncodec + +import ( + "encoding/json" + "errors" + "fmt" + "math" + "net/url" + "reflect" + "strconv" + "time" + + "go.mongodb.org/mongo-driver/bson/bsonrw" + "go.mongodb.org/mongo-driver/bson/bsontype" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" +) + +var ( + defaultValueDecoders DefaultValueDecoders + errCannotTruncate = errors.New("float64 can only be truncated to an integer type when truncation is enabled") +) + +type decodeBinaryError struct { + subtype byte + typeName string +} + +func (d decodeBinaryError) Error() string { + return fmt.Sprintf("only binary values with subtype 0x00 or 0x02 can be decoded into %s, but got subtype %v", d.typeName, d.subtype) +} + +func newDefaultStructCodec() *StructCodec { + codec, err := NewStructCodec(DefaultStructTagParser) + if err != nil { + // This function is called from the codec registration path, so errors can't be propagated. If there's an error + // constructing the StructCodec, we panic to avoid losing it. + panic(fmt.Errorf("error creating default StructCodec: %v", err)) + } + return codec +} + +// DefaultValueDecoders is a namespace type for the default ValueDecoders used +// when creating a registry. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. +type DefaultValueDecoders struct{} + +// RegisterDefaultDecoders will register the decoder methods attached to DefaultValueDecoders with +// the provided RegistryBuilder. +// +// There is no support for decoding map[string]interface{} because there is no decoder for +// interface{}, so users must either register this decoder themselves or use the +// EmptyInterfaceDecoder available in the bson package. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. +func (dvd DefaultValueDecoders) RegisterDefaultDecoders(rb *RegistryBuilder) { + if rb == nil { + panic(errors.New("argument to RegisterDefaultDecoders must not be nil")) + } + + intDecoder := decodeAdapter{dvd.IntDecodeValue, dvd.intDecodeType} + floatDecoder := decodeAdapter{dvd.FloatDecodeValue, dvd.floatDecodeType} + + rb. + RegisterTypeDecoder(tD, ValueDecoderFunc(dvd.DDecodeValue)). + RegisterTypeDecoder(tBinary, decodeAdapter{dvd.BinaryDecodeValue, dvd.binaryDecodeType}). + RegisterTypeDecoder(tUndefined, decodeAdapter{dvd.UndefinedDecodeValue, dvd.undefinedDecodeType}). + RegisterTypeDecoder(tDateTime, decodeAdapter{dvd.DateTimeDecodeValue, dvd.dateTimeDecodeType}). + RegisterTypeDecoder(tNull, decodeAdapter{dvd.NullDecodeValue, dvd.nullDecodeType}). + RegisterTypeDecoder(tRegex, decodeAdapter{dvd.RegexDecodeValue, dvd.regexDecodeType}). + RegisterTypeDecoder(tDBPointer, decodeAdapter{dvd.DBPointerDecodeValue, dvd.dBPointerDecodeType}). + RegisterTypeDecoder(tTimestamp, decodeAdapter{dvd.TimestampDecodeValue, dvd.timestampDecodeType}). + RegisterTypeDecoder(tMinKey, decodeAdapter{dvd.MinKeyDecodeValue, dvd.minKeyDecodeType}). + RegisterTypeDecoder(tMaxKey, decodeAdapter{dvd.MaxKeyDecodeValue, dvd.maxKeyDecodeType}). + RegisterTypeDecoder(tJavaScript, decodeAdapter{dvd.JavaScriptDecodeValue, dvd.javaScriptDecodeType}). + RegisterTypeDecoder(tSymbol, decodeAdapter{dvd.SymbolDecodeValue, dvd.symbolDecodeType}). + RegisterTypeDecoder(tByteSlice, defaultByteSliceCodec). + RegisterTypeDecoder(tTime, defaultTimeCodec). 
+ RegisterTypeDecoder(tEmpty, defaultEmptyInterfaceCodec). + RegisterTypeDecoder(tCoreArray, defaultArrayCodec). + RegisterTypeDecoder(tOID, decodeAdapter{dvd.ObjectIDDecodeValue, dvd.objectIDDecodeType}). + RegisterTypeDecoder(tDecimal, decodeAdapter{dvd.Decimal128DecodeValue, dvd.decimal128DecodeType}). + RegisterTypeDecoder(tJSONNumber, decodeAdapter{dvd.JSONNumberDecodeValue, dvd.jsonNumberDecodeType}). + RegisterTypeDecoder(tURL, decodeAdapter{dvd.URLDecodeValue, dvd.urlDecodeType}). + RegisterTypeDecoder(tCoreDocument, ValueDecoderFunc(dvd.CoreDocumentDecodeValue)). + RegisterTypeDecoder(tCodeWithScope, decodeAdapter{dvd.CodeWithScopeDecodeValue, dvd.codeWithScopeDecodeType}). + RegisterDefaultDecoder(reflect.Bool, decodeAdapter{dvd.BooleanDecodeValue, dvd.booleanDecodeType}). + RegisterDefaultDecoder(reflect.Int, intDecoder). + RegisterDefaultDecoder(reflect.Int8, intDecoder). + RegisterDefaultDecoder(reflect.Int16, intDecoder). + RegisterDefaultDecoder(reflect.Int32, intDecoder). + RegisterDefaultDecoder(reflect.Int64, intDecoder). + RegisterDefaultDecoder(reflect.Uint, defaultUIntCodec). + RegisterDefaultDecoder(reflect.Uint8, defaultUIntCodec). + RegisterDefaultDecoder(reflect.Uint16, defaultUIntCodec). + RegisterDefaultDecoder(reflect.Uint32, defaultUIntCodec). + RegisterDefaultDecoder(reflect.Uint64, defaultUIntCodec). + RegisterDefaultDecoder(reflect.Float32, floatDecoder). + RegisterDefaultDecoder(reflect.Float64, floatDecoder). + RegisterDefaultDecoder(reflect.Array, ValueDecoderFunc(dvd.ArrayDecodeValue)). + RegisterDefaultDecoder(reflect.Map, defaultMapCodec). + RegisterDefaultDecoder(reflect.Slice, defaultSliceCodec). + RegisterDefaultDecoder(reflect.String, defaultStringCodec). + RegisterDefaultDecoder(reflect.Struct, newDefaultStructCodec()). + RegisterDefaultDecoder(reflect.Ptr, NewPointerCodec()). + RegisterTypeMapEntry(bsontype.Double, tFloat64). + RegisterTypeMapEntry(bsontype.String, tString). + RegisterTypeMapEntry(bsontype.Array, tA). + RegisterTypeMapEntry(bsontype.Binary, tBinary). + RegisterTypeMapEntry(bsontype.Undefined, tUndefined). + RegisterTypeMapEntry(bsontype.ObjectID, tOID). + RegisterTypeMapEntry(bsontype.Boolean, tBool). + RegisterTypeMapEntry(bsontype.DateTime, tDateTime). + RegisterTypeMapEntry(bsontype.Regex, tRegex). + RegisterTypeMapEntry(bsontype.DBPointer, tDBPointer). + RegisterTypeMapEntry(bsontype.JavaScript, tJavaScript). + RegisterTypeMapEntry(bsontype.Symbol, tSymbol). + RegisterTypeMapEntry(bsontype.CodeWithScope, tCodeWithScope). + RegisterTypeMapEntry(bsontype.Int32, tInt32). + RegisterTypeMapEntry(bsontype.Int64, tInt64). + RegisterTypeMapEntry(bsontype.Timestamp, tTimestamp). + RegisterTypeMapEntry(bsontype.Decimal128, tDecimal). + RegisterTypeMapEntry(bsontype.MinKey, tMinKey). + RegisterTypeMapEntry(bsontype.MaxKey, tMaxKey). + RegisterTypeMapEntry(bsontype.Type(0), tD). + RegisterTypeMapEntry(bsontype.EmbeddedDocument, tD). + RegisterHookDecoder(tValueUnmarshaler, ValueDecoderFunc(dvd.ValueUnmarshalerDecodeValue)). + RegisterHookDecoder(tUnmarshaler, ValueDecoderFunc(dvd.UnmarshalerDecodeValue)) +} + +// DDecodeValue is the ValueDecoderFunc for primitive.D instances. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. 
+func (dvd DefaultValueDecoders) DDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.IsValid() || !val.CanSet() || val.Type() != tD { + return ValueDecoderError{Name: "DDecodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val} + } + + switch vrType := vr.Type(); vrType { + case bsontype.Type(0), bsontype.EmbeddedDocument: + dc.Ancestor = tD + case bsontype.Null: + val.Set(reflect.Zero(val.Type())) + return vr.ReadNull() + default: + return fmt.Errorf("cannot decode %v into a primitive.D", vrType) + } + + dr, err := vr.ReadDocument() + if err != nil { + return err + } + + decoder, err := dc.LookupDecoder(tEmpty) + if err != nil { + return err + } + tEmptyTypeDecoder, _ := decoder.(typeDecoder) + + // Use the elements in the provided value if it's non nil. Otherwise, allocate a new D instance. + var elems primitive.D + if !val.IsNil() { + val.SetLen(0) + elems = val.Interface().(primitive.D) + } else { + elems = make(primitive.D, 0) + } + + for { + key, elemVr, err := dr.ReadElement() + if err == bsonrw.ErrEOD { + break + } else if err != nil { + return err + } + + // Pass false for convert because we don't need to call reflect.Value.Convert for tEmpty. + elem, err := decodeTypeOrValueWithInfo(decoder, tEmptyTypeDecoder, dc, elemVr, tEmpty, false) + if err != nil { + return err + } + + elems = append(elems, primitive.E{Key: key, Value: elem.Interface()}) + } + + val.Set(reflect.ValueOf(elems)) + return nil +} + +func (dvd DefaultValueDecoders) booleanDecodeType(_ DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { + if t.Kind() != reflect.Bool { + return emptyValue, ValueDecoderError{ + Name: "BooleanDecodeValue", + Kinds: []reflect.Kind{reflect.Bool}, + Received: reflect.Zero(t), + } + } + + var b bool + var err error + switch vrType := vr.Type(); vrType { + case bsontype.Int32: + i32, err := vr.ReadInt32() + if err != nil { + return emptyValue, err + } + b = (i32 != 0) + case bsontype.Int64: + i64, err := vr.ReadInt64() + if err != nil { + return emptyValue, err + } + b = (i64 != 0) + case bsontype.Double: + f64, err := vr.ReadDouble() + if err != nil { + return emptyValue, err + } + b = (f64 != 0) + case bsontype.Boolean: + b, err = vr.ReadBoolean() + case bsontype.Null: + err = vr.ReadNull() + case bsontype.Undefined: + err = vr.ReadUndefined() + default: + return emptyValue, fmt.Errorf("cannot decode %v into a boolean", vrType) + } + if err != nil { + return emptyValue, err + } + + return reflect.ValueOf(b), nil +} + +// BooleanDecodeValue is the ValueDecoderFunc for bool types. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. 
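+//
+// As a sketch of the coercion rules above, any numeric BSON value decodes
+// into a Go bool as "value != 0" (raw is assumed to encode {"ok": 1}):
+//
+//	var out struct {
+//		OK bool `bson:"ok"`
+//	}
+//	err := bson.Unmarshal(raw, &out) // out.OK == true, since 1 != 0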
+func (dvd DefaultValueDecoders) BooleanDecodeValue(dctx DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.IsValid() || !val.CanSet() || val.Kind() != reflect.Bool { + return ValueDecoderError{Name: "BooleanDecodeValue", Kinds: []reflect.Kind{reflect.Bool}, Received: val} + } + + elem, err := dvd.booleanDecodeType(dctx, vr, val.Type()) + if err != nil { + return err + } + + val.SetBool(elem.Bool()) + return nil +} + +func (DefaultValueDecoders) intDecodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { + var i64 int64 + var err error + switch vrType := vr.Type(); vrType { + case bsontype.Int32: + i32, err := vr.ReadInt32() + if err != nil { + return emptyValue, err + } + i64 = int64(i32) + case bsontype.Int64: + i64, err = vr.ReadInt64() + if err != nil { + return emptyValue, err + } + case bsontype.Double: + f64, err := vr.ReadDouble() + if err != nil { + return emptyValue, err + } + if !dc.Truncate && math.Floor(f64) != f64 { + return emptyValue, errCannotTruncate + } + if f64 > float64(math.MaxInt64) { + return emptyValue, fmt.Errorf("%g overflows int64", f64) + } + i64 = int64(f64) + case bsontype.Boolean: + b, err := vr.ReadBoolean() + if err != nil { + return emptyValue, err + } + if b { + i64 = 1 + } + case bsontype.Null: + if err = vr.ReadNull(); err != nil { + return emptyValue, err + } + case bsontype.Undefined: + if err = vr.ReadUndefined(); err != nil { + return emptyValue, err + } + default: + return emptyValue, fmt.Errorf("cannot decode %v into an integer type", vrType) + } + + switch t.Kind() { + case reflect.Int8: + if i64 < math.MinInt8 || i64 > math.MaxInt8 { + return emptyValue, fmt.Errorf("%d overflows int8", i64) + } + + return reflect.ValueOf(int8(i64)), nil + case reflect.Int16: + if i64 < math.MinInt16 || i64 > math.MaxInt16 { + return emptyValue, fmt.Errorf("%d overflows int16", i64) + } + + return reflect.ValueOf(int16(i64)), nil + case reflect.Int32: + if i64 < math.MinInt32 || i64 > math.MaxInt32 { + return emptyValue, fmt.Errorf("%d overflows int32", i64) + } + + return reflect.ValueOf(int32(i64)), nil + case reflect.Int64: + return reflect.ValueOf(i64), nil + case reflect.Int: + if int64(int(i64)) != i64 { // Can we fit this inside of an int + return emptyValue, fmt.Errorf("%d overflows int", i64) + } + + return reflect.ValueOf(int(i64)), nil + default: + return emptyValue, ValueDecoderError{ + Name: "IntDecodeValue", + Kinds: []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int}, + Received: reflect.Zero(t), + } + } +} + +// IntDecodeValue is the ValueDecoderFunc for int types. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. +func (dvd DefaultValueDecoders) IntDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() { + return ValueDecoderError{ + Name: "IntDecodeValue", + Kinds: []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int}, + Received: val, + } + } + + elem, err := dvd.intDecodeType(dc, vr, val.Type()) + if err != nil { + return err + } + + val.SetInt(elem.Int()) + return nil +} + +// UintDecodeValue is the ValueDecoderFunc for uint types. +// +// Deprecated: UintDecodeValue is not registered by default. Use UintCodec.DecodeValue instead. 
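+//
+// A minimal migration sketch (rb is assumed to be a *RegistryBuilder):
+//
+//	uintCodec := NewUIntCodec()
+//	for _, k := range []reflect.Kind{reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64} {
+//		rb.RegisterDefaultDecoder(k, uintCodec)
+//	}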
+func (dvd DefaultValueDecoders) UintDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + var i64 int64 + var err error + switch vr.Type() { + case bsontype.Int32: + i32, err := vr.ReadInt32() + if err != nil { + return err + } + i64 = int64(i32) + case bsontype.Int64: + i64, err = vr.ReadInt64() + if err != nil { + return err + } + case bsontype.Double: + f64, err := vr.ReadDouble() + if err != nil { + return err + } + if !dc.Truncate && math.Floor(f64) != f64 { + return errors.New("UintDecodeValue can only truncate float64 to an integer type when truncation is enabled") + } + if f64 > float64(math.MaxInt64) { + return fmt.Errorf("%g overflows int64", f64) + } + i64 = int64(f64) + case bsontype.Boolean: + b, err := vr.ReadBoolean() + if err != nil { + return err + } + if b { + i64 = 1 + } + default: + return fmt.Errorf("cannot decode %v into an integer type", vr.Type()) + } + + if !val.CanSet() { + return ValueDecoderError{ + Name: "UintDecodeValue", + Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, + Received: val, + } + } + + switch val.Kind() { + case reflect.Uint8: + if i64 < 0 || i64 > math.MaxUint8 { + return fmt.Errorf("%d overflows uint8", i64) + } + case reflect.Uint16: + if i64 < 0 || i64 > math.MaxUint16 { + return fmt.Errorf("%d overflows uint16", i64) + } + case reflect.Uint32: + if i64 < 0 || i64 > math.MaxUint32 { + return fmt.Errorf("%d overflows uint32", i64) + } + case reflect.Uint64: + if i64 < 0 { + return fmt.Errorf("%d overflows uint64", i64) + } + case reflect.Uint: + if i64 < 0 || int64(uint(i64)) != i64 { // Can we fit this inside of an uint + return fmt.Errorf("%d overflows uint", i64) + } + default: + return ValueDecoderError{ + Name: "UintDecodeValue", + Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, + Received: val, + } + } + + val.SetUint(uint64(i64)) + return nil +} + +func (dvd DefaultValueDecoders) floatDecodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { + var f float64 + var err error + switch vrType := vr.Type(); vrType { + case bsontype.Int32: + i32, err := vr.ReadInt32() + if err != nil { + return emptyValue, err + } + f = float64(i32) + case bsontype.Int64: + i64, err := vr.ReadInt64() + if err != nil { + return emptyValue, err + } + f = float64(i64) + case bsontype.Double: + f, err = vr.ReadDouble() + if err != nil { + return emptyValue, err + } + case bsontype.Boolean: + b, err := vr.ReadBoolean() + if err != nil { + return emptyValue, err + } + if b { + f = 1 + } + case bsontype.Null: + if err = vr.ReadNull(); err != nil { + return emptyValue, err + } + case bsontype.Undefined: + if err = vr.ReadUndefined(); err != nil { + return emptyValue, err + } + default: + return emptyValue, fmt.Errorf("cannot decode %v into a float32 or float64 type", vrType) + } + + switch t.Kind() { + case reflect.Float32: + if !dc.Truncate && float64(float32(f)) != f { + return emptyValue, errCannotTruncate + } + + return reflect.ValueOf(float32(f)), nil + case reflect.Float64: + return reflect.ValueOf(f), nil + default: + return emptyValue, ValueDecoderError{ + Name: "FloatDecodeValue", + Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, + Received: reflect.Zero(t), + } + } +} + +// FloatDecodeValue is the ValueDecoderFunc for float types. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. 
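+//
+// Truncation note, as a sketch: a BSON double that does not survive the
+// round trip through float32 returns errCannotTruncate unless the decode
+// opts in via the context (reg is assumed to be a previously built *Registry):
+//
+//	dc := DecodeContext{Registry: reg, Truncate: true}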
+func (dvd DefaultValueDecoders) FloatDecodeValue(ec DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() { + return ValueDecoderError{ + Name: "FloatDecodeValue", + Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, + Received: val, + } + } + + elem, err := dvd.floatDecodeType(ec, vr, val.Type()) + if err != nil { + return err + } + + val.SetFloat(elem.Float()) + return nil +} + +// StringDecodeValue is the ValueDecoderFunc for string types. +// +// Deprecated: StringDecodeValue is not registered by default. Use StringCodec.DecodeValue instead. +func (dvd DefaultValueDecoders) StringDecodeValue(_ DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + var str string + var err error + switch vr.Type() { + // TODO(GODRIVER-577): Handle JavaScript and Symbol BSON types when allowed. + case bsontype.String: + str, err = vr.ReadString() + if err != nil { + return err + } + default: + return fmt.Errorf("cannot decode %v into a string type", vr.Type()) + } + if !val.CanSet() || val.Kind() != reflect.String { + return ValueDecoderError{Name: "StringDecodeValue", Kinds: []reflect.Kind{reflect.String}, Received: val} + } + + val.SetString(str) + return nil +} + +func (DefaultValueDecoders) javaScriptDecodeType(_ DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { + if t != tJavaScript { + return emptyValue, ValueDecoderError{ + Name: "JavaScriptDecodeValue", + Types: []reflect.Type{tJavaScript}, + Received: reflect.Zero(t), + } + } + + var js string + var err error + switch vrType := vr.Type(); vrType { + case bsontype.JavaScript: + js, err = vr.ReadJavascript() + case bsontype.Null: + err = vr.ReadNull() + case bsontype.Undefined: + err = vr.ReadUndefined() + default: + return emptyValue, fmt.Errorf("cannot decode %v into a primitive.JavaScript", vrType) + } + if err != nil { + return emptyValue, err + } + + return reflect.ValueOf(primitive.JavaScript(js)), nil +} + +// JavaScriptDecodeValue is the ValueDecoderFunc for the primitive.JavaScript type. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. 
+func (dvd DefaultValueDecoders) JavaScriptDecodeValue(dctx DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() || val.Type() != tJavaScript { + return ValueDecoderError{Name: "JavaScriptDecodeValue", Types: []reflect.Type{tJavaScript}, Received: val} + } + + elem, err := dvd.javaScriptDecodeType(dctx, vr, tJavaScript) + if err != nil { + return err + } + + val.SetString(elem.String()) + return nil +} + +func (DefaultValueDecoders) symbolDecodeType(_ DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { + if t != tSymbol { + return emptyValue, ValueDecoderError{ + Name: "SymbolDecodeValue", + Types: []reflect.Type{tSymbol}, + Received: reflect.Zero(t), + } + } + + var symbol string + var err error + switch vrType := vr.Type(); vrType { + case bsontype.String: + symbol, err = vr.ReadString() + case bsontype.Symbol: + symbol, err = vr.ReadSymbol() + case bsontype.Binary: + data, subtype, err := vr.ReadBinary() + if err != nil { + return emptyValue, err + } + + if subtype != bsontype.BinaryGeneric && subtype != bsontype.BinaryBinaryOld { + return emptyValue, decodeBinaryError{subtype: subtype, typeName: "primitive.Symbol"} + } + symbol = string(data) + case bsontype.Null: + err = vr.ReadNull() + case bsontype.Undefined: + err = vr.ReadUndefined() + default: + return emptyValue, fmt.Errorf("cannot decode %v into a primitive.Symbol", vrType) + } + if err != nil { + return emptyValue, err + } + + return reflect.ValueOf(primitive.Symbol(symbol)), nil +} + +// SymbolDecodeValue is the ValueDecoderFunc for the primitive.Symbol type. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. +func (dvd DefaultValueDecoders) SymbolDecodeValue(dctx DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() || val.Type() != tSymbol { + return ValueDecoderError{Name: "SymbolDecodeValue", Types: []reflect.Type{tSymbol}, Received: val} + } + + elem, err := dvd.symbolDecodeType(dctx, vr, tSymbol) + if err != nil { + return err + } + + val.SetString(elem.String()) + return nil +} + +func (DefaultValueDecoders) binaryDecodeType(_ DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { + if t != tBinary { + return emptyValue, ValueDecoderError{ + Name: "BinaryDecodeValue", + Types: []reflect.Type{tBinary}, + Received: reflect.Zero(t), + } + } + + var data []byte + var subtype byte + var err error + switch vrType := vr.Type(); vrType { + case bsontype.Binary: + data, subtype, err = vr.ReadBinary() + case bsontype.Null: + err = vr.ReadNull() + case bsontype.Undefined: + err = vr.ReadUndefined() + default: + return emptyValue, fmt.Errorf("cannot decode %v into a Binary", vrType) + } + if err != nil { + return emptyValue, err + } + + return reflect.ValueOf(primitive.Binary{Subtype: subtype, Data: data}), nil +} + +// BinaryDecodeValue is the ValueDecoderFunc for Binary. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. 
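+//
+// For reference, the decoded value keeps the payload and the subtype byte
+// together (an illustrative literal):
+//
+//	b := primitive.Binary{Subtype: 0x00, Data: []byte{0xDE, 0xAD}}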
+func (dvd DefaultValueDecoders) BinaryDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
+	if !val.CanSet() || val.Type() != tBinary {
+		return ValueDecoderError{Name: "BinaryDecodeValue", Types: []reflect.Type{tBinary}, Received: val}
+	}
+
+	elem, err := dvd.binaryDecodeType(dc, vr, tBinary)
+	if err != nil {
+		return err
+	}
+
+	val.Set(elem)
+	return nil
+}
+
+func (DefaultValueDecoders) undefinedDecodeType(_ DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) {
+	if t != tUndefined {
+		return emptyValue, ValueDecoderError{
+			Name:     "UndefinedDecodeValue",
+			Types:    []reflect.Type{tUndefined},
+			Received: reflect.Zero(t),
+		}
+	}
+
+	var err error
+	switch vrType := vr.Type(); vrType {
+	case bsontype.Undefined:
+		err = vr.ReadUndefined()
+	case bsontype.Null:
+		err = vr.ReadNull()
+	default:
+		return emptyValue, fmt.Errorf("cannot decode %v into an Undefined", vr.Type())
+	}
+	if err != nil {
+		return emptyValue, err
+	}
+
+	return reflect.ValueOf(primitive.Undefined{}), nil
+}
+
+// UndefinedDecodeValue is the ValueDecoderFunc for Undefined.
+//
+// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default
+// value decoders registered.
+func (dvd DefaultValueDecoders) UndefinedDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
+	if !val.CanSet() || val.Type() != tUndefined {
+		return ValueDecoderError{Name: "UndefinedDecodeValue", Types: []reflect.Type{tUndefined}, Received: val}
+	}
+
+	elem, err := dvd.undefinedDecodeType(dc, vr, tUndefined)
+	if err != nil {
+		return err
+	}
+
+	val.Set(elem)
+	return nil
+}
+
+// objectIDDecodeType accepts both the 24-character hex representation of an
+// ObjectID and the raw 12-byte string form.
+func (dvd DefaultValueDecoders) objectIDDecodeType(_ DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) {
+	if t != tOID {
+		return emptyValue, ValueDecoderError{
+			Name:     "ObjectIDDecodeValue",
+			Types:    []reflect.Type{tOID},
+			Received: reflect.Zero(t),
+		}
+	}
+
+	var oid primitive.ObjectID
+	var err error
+	switch vrType := vr.Type(); vrType {
+	case bsontype.ObjectID:
+		oid, err = vr.ReadObjectID()
+		if err != nil {
+			return emptyValue, err
+		}
+	case bsontype.String:
+		str, err := vr.ReadString()
+		if err != nil {
+			return emptyValue, err
+		}
+		if oid, err = primitive.ObjectIDFromHex(str); err == nil {
+			break
+		}
+		if len(str) != 12 {
+			return emptyValue, fmt.Errorf("an ObjectID string must be exactly 12 bytes long (got %v)", len(str))
+		}
+		byteArr := []byte(str)
+		copy(oid[:], byteArr)
+	case bsontype.Null:
+		if err = vr.ReadNull(); err != nil {
+			return emptyValue, err
+		}
+	case bsontype.Undefined:
+		if err = vr.ReadUndefined(); err != nil {
+			return emptyValue, err
+		}
+	default:
+		return emptyValue, fmt.Errorf("cannot decode %v into an ObjectID", vrType)
+	}
+
+	return reflect.ValueOf(oid), nil
+}
+
+// ObjectIDDecodeValue is the ValueDecoderFunc for primitive.ObjectID.
+//
+// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default
+// value decoders registered.
+func (dvd DefaultValueDecoders) ObjectIDDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() || val.Type() != tOID { + return ValueDecoderError{Name: "ObjectIDDecodeValue", Types: []reflect.Type{tOID}, Received: val} + } + + elem, err := dvd.objectIDDecodeType(dc, vr, tOID) + if err != nil { + return err + } + + val.Set(elem) + return nil +} + +func (DefaultValueDecoders) dateTimeDecodeType(_ DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { + if t != tDateTime { + return emptyValue, ValueDecoderError{ + Name: "DateTimeDecodeValue", + Types: []reflect.Type{tDateTime}, + Received: reflect.Zero(t), + } + } + + var dt int64 + var err error + switch vrType := vr.Type(); vrType { + case bsontype.DateTime: + dt, err = vr.ReadDateTime() + case bsontype.Null: + err = vr.ReadNull() + case bsontype.Undefined: + err = vr.ReadUndefined() + default: + return emptyValue, fmt.Errorf("cannot decode %v into a DateTime", vrType) + } + if err != nil { + return emptyValue, err + } + + return reflect.ValueOf(primitive.DateTime(dt)), nil +} + +// DateTimeDecodeValue is the ValueDecoderFunc for DateTime. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. +func (dvd DefaultValueDecoders) DateTimeDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() || val.Type() != tDateTime { + return ValueDecoderError{Name: "DateTimeDecodeValue", Types: []reflect.Type{tDateTime}, Received: val} + } + + elem, err := dvd.dateTimeDecodeType(dc, vr, tDateTime) + if err != nil { + return err + } + + val.Set(elem) + return nil +} + +func (DefaultValueDecoders) nullDecodeType(_ DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { + if t != tNull { + return emptyValue, ValueDecoderError{ + Name: "NullDecodeValue", + Types: []reflect.Type{tNull}, + Received: reflect.Zero(t), + } + } + + var err error + switch vrType := vr.Type(); vrType { + case bsontype.Undefined: + err = vr.ReadUndefined() + case bsontype.Null: + err = vr.ReadNull() + default: + return emptyValue, fmt.Errorf("cannot decode %v into a Null", vr.Type()) + } + if err != nil { + return emptyValue, err + } + + return reflect.ValueOf(primitive.Null{}), nil +} + +// NullDecodeValue is the ValueDecoderFunc for Null. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. 
+func (dvd DefaultValueDecoders) NullDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() || val.Type() != tNull { + return ValueDecoderError{Name: "NullDecodeValue", Types: []reflect.Type{tNull}, Received: val} + } + + elem, err := dvd.nullDecodeType(dc, vr, tNull) + if err != nil { + return err + } + + val.Set(elem) + return nil +} + +func (DefaultValueDecoders) regexDecodeType(_ DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { + if t != tRegex { + return emptyValue, ValueDecoderError{ + Name: "RegexDecodeValue", + Types: []reflect.Type{tRegex}, + Received: reflect.Zero(t), + } + } + + var pattern, options string + var err error + switch vrType := vr.Type(); vrType { + case bsontype.Regex: + pattern, options, err = vr.ReadRegex() + case bsontype.Null: + err = vr.ReadNull() + case bsontype.Undefined: + err = vr.ReadUndefined() + default: + return emptyValue, fmt.Errorf("cannot decode %v into a Regex", vrType) + } + if err != nil { + return emptyValue, err + } + + return reflect.ValueOf(primitive.Regex{Pattern: pattern, Options: options}), nil +} + +// RegexDecodeValue is the ValueDecoderFunc for Regex. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. +func (dvd DefaultValueDecoders) RegexDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() || val.Type() != tRegex { + return ValueDecoderError{Name: "RegexDecodeValue", Types: []reflect.Type{tRegex}, Received: val} + } + + elem, err := dvd.regexDecodeType(dc, vr, tRegex) + if err != nil { + return err + } + + val.Set(elem) + return nil +} + +func (DefaultValueDecoders) dBPointerDecodeType(_ DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { + if t != tDBPointer { + return emptyValue, ValueDecoderError{ + Name: "DBPointerDecodeValue", + Types: []reflect.Type{tDBPointer}, + Received: reflect.Zero(t), + } + } + + var ns string + var pointer primitive.ObjectID + var err error + switch vrType := vr.Type(); vrType { + case bsontype.DBPointer: + ns, pointer, err = vr.ReadDBPointer() + case bsontype.Null: + err = vr.ReadNull() + case bsontype.Undefined: + err = vr.ReadUndefined() + default: + return emptyValue, fmt.Errorf("cannot decode %v into a DBPointer", vrType) + } + if err != nil { + return emptyValue, err + } + + return reflect.ValueOf(primitive.DBPointer{DB: ns, Pointer: pointer}), nil +} + +// DBPointerDecodeValue is the ValueDecoderFunc for DBPointer. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. 
+func (dvd DefaultValueDecoders) DBPointerDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() || val.Type() != tDBPointer { + return ValueDecoderError{Name: "DBPointerDecodeValue", Types: []reflect.Type{tDBPointer}, Received: val} + } + + elem, err := dvd.dBPointerDecodeType(dc, vr, tDBPointer) + if err != nil { + return err + } + + val.Set(elem) + return nil +} + +func (DefaultValueDecoders) timestampDecodeType(_ DecodeContext, vr bsonrw.ValueReader, reflectType reflect.Type) (reflect.Value, error) { + if reflectType != tTimestamp { + return emptyValue, ValueDecoderError{ + Name: "TimestampDecodeValue", + Types: []reflect.Type{tTimestamp}, + Received: reflect.Zero(reflectType), + } + } + + var t, incr uint32 + var err error + switch vrType := vr.Type(); vrType { + case bsontype.Timestamp: + t, incr, err = vr.ReadTimestamp() + case bsontype.Null: + err = vr.ReadNull() + case bsontype.Undefined: + err = vr.ReadUndefined() + default: + return emptyValue, fmt.Errorf("cannot decode %v into a Timestamp", vrType) + } + if err != nil { + return emptyValue, err + } + + return reflect.ValueOf(primitive.Timestamp{T: t, I: incr}), nil +} + +// TimestampDecodeValue is the ValueDecoderFunc for Timestamp. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. +func (dvd DefaultValueDecoders) TimestampDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() || val.Type() != tTimestamp { + return ValueDecoderError{Name: "TimestampDecodeValue", Types: []reflect.Type{tTimestamp}, Received: val} + } + + elem, err := dvd.timestampDecodeType(dc, vr, tTimestamp) + if err != nil { + return err + } + + val.Set(elem) + return nil +} + +func (DefaultValueDecoders) minKeyDecodeType(_ DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { + if t != tMinKey { + return emptyValue, ValueDecoderError{ + Name: "MinKeyDecodeValue", + Types: []reflect.Type{tMinKey}, + Received: reflect.Zero(t), + } + } + + var err error + switch vrType := vr.Type(); vrType { + case bsontype.MinKey: + err = vr.ReadMinKey() + case bsontype.Null: + err = vr.ReadNull() + case bsontype.Undefined: + err = vr.ReadUndefined() + default: + return emptyValue, fmt.Errorf("cannot decode %v into a MinKey", vr.Type()) + } + if err != nil { + return emptyValue, err + } + + return reflect.ValueOf(primitive.MinKey{}), nil +} + +// MinKeyDecodeValue is the ValueDecoderFunc for MinKey. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. 
+func (dvd DefaultValueDecoders) MinKeyDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() || val.Type() != tMinKey { + return ValueDecoderError{Name: "MinKeyDecodeValue", Types: []reflect.Type{tMinKey}, Received: val} + } + + elem, err := dvd.minKeyDecodeType(dc, vr, tMinKey) + if err != nil { + return err + } + + val.Set(elem) + return nil +} + +func (DefaultValueDecoders) maxKeyDecodeType(_ DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { + if t != tMaxKey { + return emptyValue, ValueDecoderError{ + Name: "MaxKeyDecodeValue", + Types: []reflect.Type{tMaxKey}, + Received: reflect.Zero(t), + } + } + + var err error + switch vrType := vr.Type(); vrType { + case bsontype.MaxKey: + err = vr.ReadMaxKey() + case bsontype.Null: + err = vr.ReadNull() + case bsontype.Undefined: + err = vr.ReadUndefined() + default: + return emptyValue, fmt.Errorf("cannot decode %v into a MaxKey", vr.Type()) + } + if err != nil { + return emptyValue, err + } + + return reflect.ValueOf(primitive.MaxKey{}), nil +} + +// MaxKeyDecodeValue is the ValueDecoderFunc for MaxKey. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. +func (dvd DefaultValueDecoders) MaxKeyDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() || val.Type() != tMaxKey { + return ValueDecoderError{Name: "MaxKeyDecodeValue", Types: []reflect.Type{tMaxKey}, Received: val} + } + + elem, err := dvd.maxKeyDecodeType(dc, vr, tMaxKey) + if err != nil { + return err + } + + val.Set(elem) + return nil +} + +func (dvd DefaultValueDecoders) decimal128DecodeType(_ DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { + if t != tDecimal { + return emptyValue, ValueDecoderError{ + Name: "Decimal128DecodeValue", + Types: []reflect.Type{tDecimal}, + Received: reflect.Zero(t), + } + } + + var d128 primitive.Decimal128 + var err error + switch vrType := vr.Type(); vrType { + case bsontype.Decimal128: + d128, err = vr.ReadDecimal128() + case bsontype.Null: + err = vr.ReadNull() + case bsontype.Undefined: + err = vr.ReadUndefined() + default: + return emptyValue, fmt.Errorf("cannot decode %v into a primitive.Decimal128", vr.Type()) + } + if err != nil { + return emptyValue, err + } + + return reflect.ValueOf(d128), nil +} + +// Decimal128DecodeValue is the ValueDecoderFunc for primitive.Decimal128. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. 
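+//
+// Usage sketch: values of this type are typically constructed with
+// primitive.ParseDecimal128 (error handling elided):
+//
+//	d, _ := primitive.ParseDecimal128("9.99")
+//	// d round-trips through BSON decimal128; d.String() == "9.99"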
+func (dvd DefaultValueDecoders) Decimal128DecodeValue(dctx DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() || val.Type() != tDecimal { + return ValueDecoderError{Name: "Decimal128DecodeValue", Types: []reflect.Type{tDecimal}, Received: val} + } + + elem, err := dvd.decimal128DecodeType(dctx, vr, tDecimal) + if err != nil { + return err + } + + val.Set(elem) + return nil +} + +func (dvd DefaultValueDecoders) jsonNumberDecodeType(_ DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { + if t != tJSONNumber { + return emptyValue, ValueDecoderError{ + Name: "JSONNumberDecodeValue", + Types: []reflect.Type{tJSONNumber}, + Received: reflect.Zero(t), + } + } + + var jsonNum json.Number + var err error + switch vrType := vr.Type(); vrType { + case bsontype.Double: + f64, err := vr.ReadDouble() + if err != nil { + return emptyValue, err + } + jsonNum = json.Number(strconv.FormatFloat(f64, 'f', -1, 64)) + case bsontype.Int32: + i32, err := vr.ReadInt32() + if err != nil { + return emptyValue, err + } + jsonNum = json.Number(strconv.FormatInt(int64(i32), 10)) + case bsontype.Int64: + i64, err := vr.ReadInt64() + if err != nil { + return emptyValue, err + } + jsonNum = json.Number(strconv.FormatInt(i64, 10)) + case bsontype.Null: + err = vr.ReadNull() + case bsontype.Undefined: + err = vr.ReadUndefined() + default: + return emptyValue, fmt.Errorf("cannot decode %v into a json.Number", vrType) + } + if err != nil { + return emptyValue, err + } + + return reflect.ValueOf(jsonNum), nil +} + +// JSONNumberDecodeValue is the ValueDecoderFunc for json.Number. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. +func (dvd DefaultValueDecoders) JSONNumberDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() || val.Type() != tJSONNumber { + return ValueDecoderError{Name: "JSONNumberDecodeValue", Types: []reflect.Type{tJSONNumber}, Received: val} + } + + elem, err := dvd.jsonNumberDecodeType(dc, vr, tJSONNumber) + if err != nil { + return err + } + + val.Set(elem) + return nil +} + +func (dvd DefaultValueDecoders) urlDecodeType(_ DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { + if t != tURL { + return emptyValue, ValueDecoderError{ + Name: "URLDecodeValue", + Types: []reflect.Type{tURL}, + Received: reflect.Zero(t), + } + } + + urlPtr := &url.URL{} + var err error + switch vrType := vr.Type(); vrType { + case bsontype.String: + var str string // Declare str here to avoid shadowing err during the ReadString call. + str, err = vr.ReadString() + if err != nil { + return emptyValue, err + } + + urlPtr, err = url.Parse(str) + case bsontype.Null: + err = vr.ReadNull() + case bsontype.Undefined: + err = vr.ReadUndefined() + default: + return emptyValue, fmt.Errorf("cannot decode %v into a *url.URL", vrType) + } + if err != nil { + return emptyValue, err + } + + return reflect.ValueOf(urlPtr).Elem(), nil +} + +// URLDecodeValue is the ValueDecoderFunc for url.URL. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. 
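+//
+// Wire-shape sketch: url.URL values are stored as BSON strings and re-parsed
+// with url.Parse on decode, so a field such as the one below (the "page"
+// type is illustrative only) round-trips through (*url.URL).String:
+//
+//	type page struct {
+//		Home url.URL `bson:"home"`
+//	}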
+func (dvd DefaultValueDecoders) URLDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() || val.Type() != tURL { + return ValueDecoderError{Name: "URLDecodeValue", Types: []reflect.Type{tURL}, Received: val} + } + + elem, err := dvd.urlDecodeType(dc, vr, tURL) + if err != nil { + return err + } + + val.Set(elem) + return nil +} + +// TimeDecodeValue is the ValueDecoderFunc for time.Time. +// +// Deprecated: TimeDecodeValue is not registered by default. Use TimeCodec.DecodeValue instead. +func (dvd DefaultValueDecoders) TimeDecodeValue(_ DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if vr.Type() != bsontype.DateTime { + return fmt.Errorf("cannot decode %v into a time.Time", vr.Type()) + } + + dt, err := vr.ReadDateTime() + if err != nil { + return err + } + + if !val.CanSet() || val.Type() != tTime { + return ValueDecoderError{Name: "TimeDecodeValue", Types: []reflect.Type{tTime}, Received: val} + } + + val.Set(reflect.ValueOf(time.Unix(dt/1000, dt%1000*1000000).UTC())) + return nil +} + +// ByteSliceDecodeValue is the ValueDecoderFunc for []byte. +// +// Deprecated: ByteSliceDecodeValue is not registered by default. Use ByteSliceCodec.DecodeValue instead. +func (dvd DefaultValueDecoders) ByteSliceDecodeValue(_ DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if vr.Type() != bsontype.Binary && vr.Type() != bsontype.Null { + return fmt.Errorf("cannot decode %v into a []byte", vr.Type()) + } + + if !val.CanSet() || val.Type() != tByteSlice { + return ValueDecoderError{Name: "ByteSliceDecodeValue", Types: []reflect.Type{tByteSlice}, Received: val} + } + + if vr.Type() == bsontype.Null { + val.Set(reflect.Zero(val.Type())) + return vr.ReadNull() + } + + data, subtype, err := vr.ReadBinary() + if err != nil { + return err + } + if subtype != 0x00 { + return fmt.Errorf("ByteSliceDecodeValue can only be used to decode subtype 0x00 for %s, got %v", bsontype.Binary, subtype) + } + + val.Set(reflect.ValueOf(data)) + return nil +} + +// MapDecodeValue is the ValueDecoderFunc for map[string]* types. +// +// Deprecated: MapDecodeValue is not registered by default. Use MapCodec.DecodeValue instead. +func (dvd DefaultValueDecoders) MapDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() || val.Kind() != reflect.Map || val.Type().Key().Kind() != reflect.String { + return ValueDecoderError{Name: "MapDecodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val} + } + + switch vr.Type() { + case bsontype.Type(0), bsontype.EmbeddedDocument: + case bsontype.Null: + val.Set(reflect.Zero(val.Type())) + return vr.ReadNull() + default: + return fmt.Errorf("cannot decode %v into a %s", vr.Type(), val.Type()) + } + + dr, err := vr.ReadDocument() + if err != nil { + return err + } + + if val.IsNil() { + val.Set(reflect.MakeMap(val.Type())) + } + + eType := val.Type().Elem() + decoder, err := dc.LookupDecoder(eType) + if err != nil { + return err + } + + if eType == tEmpty { + dc.Ancestor = val.Type() + } + + keyType := val.Type().Key() + for { + key, vr, err := dr.ReadElement() + if err == bsonrw.ErrEOD { + break + } + if err != nil { + return err + } + + elem := reflect.New(eType).Elem() + + err = decoder.DecodeValue(dc, vr, elem) + if err != nil { + return err + } + + val.SetMapIndex(reflect.ValueOf(key).Convert(keyType), elem) + } + return nil +} + +// ArrayDecodeValue is the ValueDecoderFunc for array types. 
+// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. +func (dvd DefaultValueDecoders) ArrayDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.IsValid() || val.Kind() != reflect.Array { + return ValueDecoderError{Name: "ArrayDecodeValue", Kinds: []reflect.Kind{reflect.Array}, Received: val} + } + + switch vrType := vr.Type(); vrType { + case bsontype.Array: + case bsontype.Type(0), bsontype.EmbeddedDocument: + if val.Type().Elem() != tE { + return fmt.Errorf("cannot decode document into %s", val.Type()) + } + case bsontype.Binary: + if val.Type().Elem() != tByte { + return fmt.Errorf("ArrayDecodeValue can only be used to decode binary into a byte array, got %v", vrType) + } + data, subtype, err := vr.ReadBinary() + if err != nil { + return err + } + if subtype != bsontype.BinaryGeneric && subtype != bsontype.BinaryBinaryOld { + return fmt.Errorf("ArrayDecodeValue can only be used to decode subtype 0x00 or 0x02 for %s, got %v", bsontype.Binary, subtype) + } + + if len(data) > val.Len() { + return fmt.Errorf("more elements returned in array than can fit inside %s", val.Type()) + } + + for idx, elem := range data { + val.Index(idx).Set(reflect.ValueOf(elem)) + } + return nil + case bsontype.Null: + val.Set(reflect.Zero(val.Type())) + return vr.ReadNull() + case bsontype.Undefined: + val.Set(reflect.Zero(val.Type())) + return vr.ReadUndefined() + default: + return fmt.Errorf("cannot decode %v into an array", vrType) + } + + var elemsFunc func(DecodeContext, bsonrw.ValueReader, reflect.Value) ([]reflect.Value, error) + switch val.Type().Elem() { + case tE: + elemsFunc = dvd.decodeD + default: + elemsFunc = dvd.decodeDefault + } + + elems, err := elemsFunc(dc, vr, val) + if err != nil { + return err + } + + if len(elems) > val.Len() { + return fmt.Errorf("more elements returned in array than can fit inside %s, got %v elements", val.Type(), len(elems)) + } + + for idx, elem := range elems { + val.Index(idx).Set(elem) + } + + return nil +} + +// SliceDecodeValue is the ValueDecoderFunc for slice types. +// +// Deprecated: SliceDecodeValue is not registered by default. Use SliceCodec.DecodeValue instead. +func (dvd DefaultValueDecoders) SliceDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() || val.Kind() != reflect.Slice { + return ValueDecoderError{Name: "SliceDecodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val} + } + + switch vr.Type() { + case bsontype.Array: + case bsontype.Null: + val.Set(reflect.Zero(val.Type())) + return vr.ReadNull() + case bsontype.Type(0), bsontype.EmbeddedDocument: + if val.Type().Elem() != tE { + return fmt.Errorf("cannot decode document into %s", val.Type()) + } + default: + return fmt.Errorf("cannot decode %v into a slice", vr.Type()) + } + + var elemsFunc func(DecodeContext, bsonrw.ValueReader, reflect.Value) ([]reflect.Value, error) + switch val.Type().Elem() { + case tE: + dc.Ancestor = val.Type() + elemsFunc = dvd.decodeD + default: + elemsFunc = dvd.decodeDefault + } + + elems, err := elemsFunc(dc, vr, val) + if err != nil { + return err + } + + if val.IsNil() { + val.Set(reflect.MakeSlice(val.Type(), 0, len(elems))) + } + + val.SetLen(0) + val.Set(reflect.Append(val, elems...)) + + return nil +} + +// ValueUnmarshalerDecodeValue is the ValueDecoderFunc for ValueUnmarshaler implementations. 
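+//
+// Types opt in by implementing UnmarshalBSONValue; a rough sketch (the
+// Celsius type is hypothetical, and a real implementation must interpret
+// t and data rather than ignore them):
+//
+//	type Celsius float64
+//
+//	func (c *Celsius) UnmarshalBSONValue(t bsontype.Type, data []byte) error {
+//		// inspect t, decode data, assign through c
+//		return nil
+//	}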
+// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. +func (dvd DefaultValueDecoders) ValueUnmarshalerDecodeValue(_ DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.IsValid() || (!val.Type().Implements(tValueUnmarshaler) && !reflect.PtrTo(val.Type()).Implements(tValueUnmarshaler)) { + return ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tValueUnmarshaler}, Received: val} + } + + if val.Kind() == reflect.Ptr && val.IsNil() { + if !val.CanSet() { + return ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tValueUnmarshaler}, Received: val} + } + val.Set(reflect.New(val.Type().Elem())) + } + + if !val.Type().Implements(tValueUnmarshaler) { + if !val.CanAddr() { + return ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tValueUnmarshaler}, Received: val} + } + val = val.Addr() // If the type doesn't implement the interface, a pointer to it must. + } + + t, src, err := bsonrw.Copier{}.CopyValueToBytes(vr) + if err != nil { + return err + } + + fn := val.Convert(tValueUnmarshaler).MethodByName("UnmarshalBSONValue") + errVal := fn.Call([]reflect.Value{reflect.ValueOf(t), reflect.ValueOf(src)})[0] + if !errVal.IsNil() { + return errVal.Interface().(error) + } + return nil +} + +// UnmarshalerDecodeValue is the ValueDecoderFunc for Unmarshaler implementations. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. +func (dvd DefaultValueDecoders) UnmarshalerDecodeValue(_ DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.IsValid() || (!val.Type().Implements(tUnmarshaler) && !reflect.PtrTo(val.Type()).Implements(tUnmarshaler)) { + return ValueDecoderError{Name: "UnmarshalerDecodeValue", Types: []reflect.Type{tUnmarshaler}, Received: val} + } + + if val.Kind() == reflect.Ptr && val.IsNil() { + if !val.CanSet() { + return ValueDecoderError{Name: "UnmarshalerDecodeValue", Types: []reflect.Type{tUnmarshaler}, Received: val} + } + val.Set(reflect.New(val.Type().Elem())) + } + + _, src, err := bsonrw.Copier{}.CopyValueToBytes(vr) + if err != nil { + return err + } + + // If the target Go value is a pointer and the BSON field value is empty, set the value to the + // zero value of the pointer (nil) and don't call UnmarshalBSON. UnmarshalBSON has no way to + // change the pointer value from within the function (only the value at the pointer address), + // so it can't set the pointer to "nil" itself. Since the most common Go value for an empty BSON + // field value is "nil", we set "nil" here and don't call UnmarshalBSON. This behavior matches + // the behavior of the Go "encoding/json" unmarshaler when the target Go value is a pointer and + // the JSON field value is "null". + if val.Kind() == reflect.Ptr && len(src) == 0 { + val.Set(reflect.Zero(val.Type())) + return nil + } + + if !val.Type().Implements(tUnmarshaler) { + if !val.CanAddr() { + return ValueDecoderError{Name: "UnmarshalerDecodeValue", Types: []reflect.Type{tUnmarshaler}, Received: val} + } + val = val.Addr() // If the type doesn't implement the interface, a pointer to it must. 
+ } + + fn := val.Convert(tUnmarshaler).MethodByName("UnmarshalBSON") + errVal := fn.Call([]reflect.Value{reflect.ValueOf(src)})[0] + if !errVal.IsNil() { + return errVal.Interface().(error) + } + return nil +} + +// EmptyInterfaceDecodeValue is the ValueDecoderFunc for interface{}. +// +// Deprecated: EmptyInterfaceDecodeValue is not registered by default. Use EmptyInterfaceCodec.DecodeValue instead. +func (dvd DefaultValueDecoders) EmptyInterfaceDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() || val.Type() != tEmpty { + return ValueDecoderError{Name: "EmptyInterfaceDecodeValue", Types: []reflect.Type{tEmpty}, Received: val} + } + + rtype, err := dc.LookupTypeMapEntry(vr.Type()) + if err != nil { + switch vr.Type() { + case bsontype.EmbeddedDocument: + if dc.Ancestor != nil { + rtype = dc.Ancestor + break + } + rtype = tD + case bsontype.Null: + val.Set(reflect.Zero(val.Type())) + return vr.ReadNull() + default: + return err + } + } + + decoder, err := dc.LookupDecoder(rtype) + if err != nil { + return err + } + + elem := reflect.New(rtype).Elem() + err = decoder.DecodeValue(dc, vr, elem) + if err != nil { + return err + } + + val.Set(elem) + return nil +} + +// CoreDocumentDecodeValue is the ValueDecoderFunc for bsoncore.Document. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. +func (DefaultValueDecoders) CoreDocumentDecodeValue(_ DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() || val.Type() != tCoreDocument { + return ValueDecoderError{Name: "CoreDocumentDecodeValue", Types: []reflect.Type{tCoreDocument}, Received: val} + } + + if val.IsNil() { + val.Set(reflect.MakeSlice(val.Type(), 0, 0)) + } + + val.SetLen(0) + + cdoc, err := bsonrw.Copier{}.AppendDocumentBytes(val.Interface().(bsoncore.Document), vr) + val.Set(reflect.ValueOf(cdoc)) + return err +} + +func (dvd DefaultValueDecoders) decodeDefault(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) ([]reflect.Value, error) { + elems := make([]reflect.Value, 0) + + ar, err := vr.ReadArray() + if err != nil { + return nil, err + } + + eType := val.Type().Elem() + + decoder, err := dc.LookupDecoder(eType) + if err != nil { + return nil, err + } + eTypeDecoder, _ := decoder.(typeDecoder) + + idx := 0 + for { + vr, err := ar.ReadValue() + if err == bsonrw.ErrEOA { + break + } + if err != nil { + return nil, err + } + + elem, err := decodeTypeOrValueWithInfo(decoder, eTypeDecoder, dc, vr, eType, true) + if err != nil { + return nil, newDecodeError(strconv.Itoa(idx), err) + } + elems = append(elems, elem) + idx++ + } + + return elems, nil +} + +func (dvd DefaultValueDecoders) readCodeWithScope(dc DecodeContext, vr bsonrw.ValueReader) (primitive.CodeWithScope, error) { + var cws primitive.CodeWithScope + + code, dr, err := vr.ReadCodeWithScope() + if err != nil { + return cws, err + } + + scope := reflect.New(tD).Elem() + elems, err := dvd.decodeElemsFromDocumentReader(dc, dr) + if err != nil { + return cws, err + } + + scope.Set(reflect.MakeSlice(tD, 0, len(elems))) + scope.Set(reflect.Append(scope, elems...)) + + cws = primitive.CodeWithScope{ + Code: primitive.JavaScript(code), + Scope: scope.Interface().(primitive.D), + } + return cws, nil +} + +func (dvd DefaultValueDecoders) codeWithScopeDecodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { + if t != tCodeWithScope { + return emptyValue, ValueDecoderError{ + Name: 
"CodeWithScopeDecodeValue", + Types: []reflect.Type{tCodeWithScope}, + Received: reflect.Zero(t), + } + } + + var cws primitive.CodeWithScope + var err error + switch vrType := vr.Type(); vrType { + case bsontype.CodeWithScope: + cws, err = dvd.readCodeWithScope(dc, vr) + case bsontype.Null: + err = vr.ReadNull() + case bsontype.Undefined: + err = vr.ReadUndefined() + default: + return emptyValue, fmt.Errorf("cannot decode %v into a primitive.CodeWithScope", vrType) + } + if err != nil { + return emptyValue, err + } + + return reflect.ValueOf(cws), nil +} + +// CodeWithScopeDecodeValue is the ValueDecoderFunc for CodeWithScope. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value decoders registered. +func (dvd DefaultValueDecoders) CodeWithScopeDecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if !val.CanSet() || val.Type() != tCodeWithScope { + return ValueDecoderError{Name: "CodeWithScopeDecodeValue", Types: []reflect.Type{tCodeWithScope}, Received: val} + } + + elem, err := dvd.codeWithScopeDecodeType(dc, vr, tCodeWithScope) + if err != nil { + return err + } + + val.Set(elem) + return nil +} + +func (dvd DefaultValueDecoders) decodeD(dc DecodeContext, vr bsonrw.ValueReader, _ reflect.Value) ([]reflect.Value, error) { + switch vr.Type() { + case bsontype.Type(0), bsontype.EmbeddedDocument: + default: + return nil, fmt.Errorf("cannot decode %v into a D", vr.Type()) + } + + dr, err := vr.ReadDocument() + if err != nil { + return nil, err + } + + return dvd.decodeElemsFromDocumentReader(dc, dr) +} + +func (DefaultValueDecoders) decodeElemsFromDocumentReader(dc DecodeContext, dr bsonrw.DocumentReader) ([]reflect.Value, error) { + decoder, err := dc.LookupDecoder(tEmpty) + if err != nil { + return nil, err + } + + elems := make([]reflect.Value, 0) + for { + key, vr, err := dr.ReadElement() + if err == bsonrw.ErrEOD { + break + } + if err != nil { + return nil, err + } + + val := reflect.New(tEmpty).Elem() + err = decoder.DecodeValue(dc, vr, val) + if err != nil { + return nil, newDecodeError(key, err) + } + + elems = append(elems, reflect.ValueOf(primitive.E{Key: key, Value: val.Interface()})) + } + + return elems, nil +} diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/default_value_encoders.go b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/default_value_encoders.go new file mode 100644 index 0000000000..7d526c4ef8 --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/bson/bsoncodec/default_value_encoders.go @@ -0,0 +1,844 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. 
You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package bsoncodec + +import ( + "encoding/json" + "errors" + "fmt" + "math" + "net/url" + "reflect" + "sync" + "time" + + "go.mongodb.org/mongo-driver/bson/bsonrw" + "go.mongodb.org/mongo-driver/bson/bsontype" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" +) + +var defaultValueEncoders DefaultValueEncoders + +var bvwPool = bsonrw.NewBSONValueWriterPool() + +var errInvalidValue = errors.New("cannot encode invalid element") + +var sliceWriterPool = sync.Pool{ + New: func() interface{} { + sw := make(bsonrw.SliceWriter, 0) + return &sw + }, +} + +func encodeElement(ec EncodeContext, dw bsonrw.DocumentWriter, e primitive.E) error { + vw, err := dw.WriteDocumentElement(e.Key) + if err != nil { + return err + } + + if e.Value == nil { + return vw.WriteNull() + } + encoder, err := ec.LookupEncoder(reflect.TypeOf(e.Value)) + if err != nil { + return err + } + + err = encoder.EncodeValue(ec, vw, reflect.ValueOf(e.Value)) + if err != nil { + return err + } + return nil +} + +// DefaultValueEncoders is a namespace type for the default ValueEncoders used +// when creating a registry. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. +type DefaultValueEncoders struct{} + +// RegisterDefaultEncoders will register the encoder methods attached to DefaultValueEncoders with +// the provided RegistryBuilder. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. +func (dve DefaultValueEncoders) RegisterDefaultEncoders(rb *RegistryBuilder) { + if rb == nil { + panic(errors.New("argument to RegisterDefaultEncoders must not be nil")) + } + rb. + RegisterTypeEncoder(tByteSlice, defaultByteSliceCodec). + RegisterTypeEncoder(tTime, defaultTimeCodec). + RegisterTypeEncoder(tEmpty, defaultEmptyInterfaceCodec). + RegisterTypeEncoder(tCoreArray, defaultArrayCodec). + RegisterTypeEncoder(tOID, ValueEncoderFunc(dve.ObjectIDEncodeValue)). + RegisterTypeEncoder(tDecimal, ValueEncoderFunc(dve.Decimal128EncodeValue)). + RegisterTypeEncoder(tJSONNumber, ValueEncoderFunc(dve.JSONNumberEncodeValue)). + RegisterTypeEncoder(tURL, ValueEncoderFunc(dve.URLEncodeValue)). + RegisterTypeEncoder(tJavaScript, ValueEncoderFunc(dve.JavaScriptEncodeValue)). + RegisterTypeEncoder(tSymbol, ValueEncoderFunc(dve.SymbolEncodeValue)). + RegisterTypeEncoder(tBinary, ValueEncoderFunc(dve.BinaryEncodeValue)). + RegisterTypeEncoder(tUndefined, ValueEncoderFunc(dve.UndefinedEncodeValue)). + RegisterTypeEncoder(tDateTime, ValueEncoderFunc(dve.DateTimeEncodeValue)). + RegisterTypeEncoder(tNull, ValueEncoderFunc(dve.NullEncodeValue)). + RegisterTypeEncoder(tRegex, ValueEncoderFunc(dve.RegexEncodeValue)). + RegisterTypeEncoder(tDBPointer, ValueEncoderFunc(dve.DBPointerEncodeValue)). + RegisterTypeEncoder(tTimestamp, ValueEncoderFunc(dve.TimestampEncodeValue)). + RegisterTypeEncoder(tMinKey, ValueEncoderFunc(dve.MinKeyEncodeValue)). + RegisterTypeEncoder(tMaxKey, ValueEncoderFunc(dve.MaxKeyEncodeValue)). + RegisterTypeEncoder(tCoreDocument, ValueEncoderFunc(dve.CoreDocumentEncodeValue)). + RegisterTypeEncoder(tCodeWithScope, ValueEncoderFunc(dve.CodeWithScopeEncodeValue)). + RegisterDefaultEncoder(reflect.Bool, ValueEncoderFunc(dve.BooleanEncodeValue)). + RegisterDefaultEncoder(reflect.Int, ValueEncoderFunc(dve.IntEncodeValue)). 
+ RegisterDefaultEncoder(reflect.Int8, ValueEncoderFunc(dve.IntEncodeValue)). + RegisterDefaultEncoder(reflect.Int16, ValueEncoderFunc(dve.IntEncodeValue)). + RegisterDefaultEncoder(reflect.Int32, ValueEncoderFunc(dve.IntEncodeValue)). + RegisterDefaultEncoder(reflect.Int64, ValueEncoderFunc(dve.IntEncodeValue)). + RegisterDefaultEncoder(reflect.Uint, defaultUIntCodec). + RegisterDefaultEncoder(reflect.Uint8, defaultUIntCodec). + RegisterDefaultEncoder(reflect.Uint16, defaultUIntCodec). + RegisterDefaultEncoder(reflect.Uint32, defaultUIntCodec). + RegisterDefaultEncoder(reflect.Uint64, defaultUIntCodec). + RegisterDefaultEncoder(reflect.Float32, ValueEncoderFunc(dve.FloatEncodeValue)). + RegisterDefaultEncoder(reflect.Float64, ValueEncoderFunc(dve.FloatEncodeValue)). + RegisterDefaultEncoder(reflect.Array, ValueEncoderFunc(dve.ArrayEncodeValue)). + RegisterDefaultEncoder(reflect.Map, defaultMapCodec). + RegisterDefaultEncoder(reflect.Slice, defaultSliceCodec). + RegisterDefaultEncoder(reflect.String, defaultStringCodec). + RegisterDefaultEncoder(reflect.Struct, newDefaultStructCodec()). + RegisterDefaultEncoder(reflect.Ptr, NewPointerCodec()). + RegisterHookEncoder(tValueMarshaler, ValueEncoderFunc(dve.ValueMarshalerEncodeValue)). + RegisterHookEncoder(tMarshaler, ValueEncoderFunc(dve.MarshalerEncodeValue)). + RegisterHookEncoder(tProxy, ValueEncoderFunc(dve.ProxyEncodeValue)) +} + +// BooleanEncodeValue is the ValueEncoderFunc for bool types. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. +func (dve DefaultValueEncoders) BooleanEncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { + if !val.IsValid() || val.Kind() != reflect.Bool { + return ValueEncoderError{Name: "BooleanEncodeValue", Kinds: []reflect.Kind{reflect.Bool}, Received: val} + } + return vw.WriteBoolean(val.Bool()) +} + +func fitsIn32Bits(i int64) bool { + return math.MinInt32 <= i && i <= math.MaxInt32 +} + +// IntEncodeValue is the ValueEncoderFunc for int types. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. +func (dve DefaultValueEncoders) IntEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { + switch val.Kind() { + case reflect.Int8, reflect.Int16, reflect.Int32: + return vw.WriteInt32(int32(val.Int())) + case reflect.Int: + i64 := val.Int() + if fitsIn32Bits(i64) { + return vw.WriteInt32(int32(i64)) + } + return vw.WriteInt64(i64) + case reflect.Int64: + i64 := val.Int() + if ec.MinSize && fitsIn32Bits(i64) { + return vw.WriteInt32(int32(i64)) + } + return vw.WriteInt64(i64) + } + + return ValueEncoderError{ + Name: "IntEncodeValue", + Kinds: []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int}, + Received: val, + } +} + +// UintEncodeValue is the ValueEncoderFunc for uint types. +// +// Deprecated: UintEncodeValue is not registered by default. Use UintCodec.EncodeValue instead. 
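+//
+// The supported replacement, sketched (rb is assumed to be a *RegistryBuilder):
+//
+//	rb.RegisterDefaultEncoder(reflect.Uint64, NewUIntCodec())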
+func (dve DefaultValueEncoders) UintEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { + switch val.Kind() { + case reflect.Uint8, reflect.Uint16: + return vw.WriteInt32(int32(val.Uint())) + case reflect.Uint, reflect.Uint32, reflect.Uint64: + u64 := val.Uint() + if ec.MinSize && u64 <= math.MaxInt32 { + return vw.WriteInt32(int32(u64)) + } + if u64 > math.MaxInt64 { + return fmt.Errorf("%d overflows int64", u64) + } + return vw.WriteInt64(int64(u64)) + } + + return ValueEncoderError{ + Name: "UintEncodeValue", + Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, + Received: val, + } +} + +// FloatEncodeValue is the ValueEncoderFunc for float types. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. +func (dve DefaultValueEncoders) FloatEncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { + switch val.Kind() { + case reflect.Float32, reflect.Float64: + return vw.WriteDouble(val.Float()) + } + + return ValueEncoderError{Name: "FloatEncodeValue", Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, Received: val} +} + +// StringEncodeValue is the ValueEncoderFunc for string types. +// +// Deprecated: StringEncodeValue is not registered by default. Use StringCodec.EncodeValue instead. +func (dve DefaultValueEncoders) StringEncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { + if val.Kind() != reflect.String { + return ValueEncoderError{ + Name: "StringEncodeValue", + Kinds: []reflect.Kind{reflect.String}, + Received: val, + } + } + + return vw.WriteString(val.String()) +} + +// ObjectIDEncodeValue is the ValueEncoderFunc for primitive.ObjectID. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. +func (dve DefaultValueEncoders) ObjectIDEncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { + if !val.IsValid() || val.Type() != tOID { + return ValueEncoderError{Name: "ObjectIDEncodeValue", Types: []reflect.Type{tOID}, Received: val} + } + return vw.WriteObjectID(val.Interface().(primitive.ObjectID)) +} + +// Decimal128EncodeValue is the ValueEncoderFunc for primitive.Decimal128. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. +func (dve DefaultValueEncoders) Decimal128EncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { + if !val.IsValid() || val.Type() != tDecimal { + return ValueEncoderError{Name: "Decimal128EncodeValue", Types: []reflect.Type{tDecimal}, Received: val} + } + return vw.WriteDecimal128(val.Interface().(primitive.Decimal128)) +} + +// JSONNumberEncodeValue is the ValueEncoderFunc for json.Number. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. 
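+//
+// Dispatch sketch: integral strings take the int path, anything else falls
+// back to a double (enc is assumed to be a *bson.Encoder):
+//
+//	_ = enc.Encode(bson.M{"n": json.Number("42")})  // written as int32/int64
+//	_ = enc.Encode(bson.M{"x": json.Number("1.5")}) // written as a double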
+func (dve DefaultValueEncoders) JSONNumberEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
+	if !val.IsValid() || val.Type() != tJSONNumber {
+		return ValueEncoderError{Name: "JSONNumberEncodeValue", Types: []reflect.Type{tJSONNumber}, Received: val}
+	}
+	jsnum := val.Interface().(json.Number)
+
+	// Attempt int first, then float64
+	if i64, err := jsnum.Int64(); err == nil {
+		return dve.IntEncodeValue(ec, vw, reflect.ValueOf(i64))
+	}
+
+	f64, err := jsnum.Float64()
+	if err != nil {
+		return err
+	}
+
+	return dve.FloatEncodeValue(ec, vw, reflect.ValueOf(f64))
+}
+
+// URLEncodeValue is the ValueEncoderFunc for url.URL.
+//
+// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default
+// value encoders registered.
+func (dve DefaultValueEncoders) URLEncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
+	if !val.IsValid() || val.Type() != tURL {
+		return ValueEncoderError{Name: "URLEncodeValue", Types: []reflect.Type{tURL}, Received: val}
+	}
+	u := val.Interface().(url.URL)
+	return vw.WriteString(u.String())
+}
+
+// TimeEncodeValue is the ValueEncoderFunc for time.Time.
+//
+// Deprecated: TimeEncodeValue is not registered by default. Use TimeCodec.EncodeValue instead.
+func (dve DefaultValueEncoders) TimeEncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
+	if !val.IsValid() || val.Type() != tTime {
+		return ValueEncoderError{Name: "TimeEncodeValue", Types: []reflect.Type{tTime}, Received: val}
+	}
+	tt := val.Interface().(time.Time)
+	dt := primitive.NewDateTimeFromTime(tt)
+	return vw.WriteDateTime(int64(dt))
+}
+
+// ByteSliceEncodeValue is the ValueEncoderFunc for []byte.
+//
+// Deprecated: ByteSliceEncodeValue is not registered by default. Use ByteSliceCodec.EncodeValue instead.
+func (dve DefaultValueEncoders) ByteSliceEncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
+	if !val.IsValid() || val.Type() != tByteSlice {
+		return ValueEncoderError{Name: "ByteSliceEncodeValue", Types: []reflect.Type{tByteSlice}, Received: val}
+	}
+	if val.IsNil() {
+		return vw.WriteNull()
+	}
+	return vw.WriteBinary(val.Interface().([]byte))
+}
+
+// MapEncodeValue is the ValueEncoderFunc for map[string]* types.
+//
+// Deprecated: MapEncodeValue is not registered by default. Use MapCodec.EncodeValue instead.
+func (dve DefaultValueEncoders) MapEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
+	if !val.IsValid() || val.Kind() != reflect.Map || val.Type().Key().Kind() != reflect.String {
+		return ValueEncoderError{Name: "MapEncodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val}
+	}
+
+	if val.IsNil() {
+		// If we have a nil map but we can't WriteNull, that means we're probably trying to encode
+		// to a TopLevel document. We can't currently tell if this is what actually happened, but if
+		// there's a deeper underlying problem, the error will also be returned from WriteDocument,
+		// so just continue. The operations on a map reflection value are valid, so we can call
+		// MapKeys within mapEncodeValue without a problem.
+		err := vw.WriteNull()
+		if err == nil {
+			return nil
+		}
+	}
+
+	dw, err := vw.WriteDocument()
+	if err != nil {
+		return err
+	}
+
+	return dve.mapEncodeValue(ec, dw, val, nil)
+}
+
+// mapEncodeValue handles encoding of the values of a map. The collisionFn returns
+// true if the provided key already exists; it is mainly used for inline maps in the
+// struct codec.
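+//
+// A minimal sketch of a collisionFn as the struct codec might supply one
+// (names hypothetical):
+//
+//	usedFields := map[string]bool{"name": true}
+//	collisionFn := func(key string) bool { return usedFields[key] }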
+func (dve DefaultValueEncoders) mapEncodeValue(ec EncodeContext, dw bsonrw.DocumentWriter, val reflect.Value, collisionFn func(string) bool) error {
+	elemType := val.Type().Elem()
+	encoder, err := ec.LookupEncoder(elemType)
+	if err != nil && elemType.Kind() != reflect.Interface {
+		return err
+	}
+
+	keys := val.MapKeys()
+	for _, key := range keys {
+		if collisionFn != nil && collisionFn(key.String()) {
+			return fmt.Errorf("key %s of inlined map conflicts with a struct field name", key.String())
+		}
+
+		currEncoder, currVal, lookupErr := dve.lookupElementEncoder(ec, encoder, val.MapIndex(key))
+		if lookupErr != nil && lookupErr != errInvalidValue {
+			return lookupErr
+		}
+
+		vw, err := dw.WriteDocumentElement(key.String())
+		if err != nil {
+			return err
+		}
+
+		if lookupErr == errInvalidValue {
+			err = vw.WriteNull()
+			if err != nil {
+				return err
+			}
+			continue
+		}
+
+		err = currEncoder.EncodeValue(ec, vw, currVal)
+		if err != nil {
+			return err
+		}
+	}
+
+	return dw.WriteDocumentEnd()
+}
+
+// ArrayEncodeValue is the ValueEncoderFunc for array types.
+//
+// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default
+// value encoders registered.
+func (dve DefaultValueEncoders) ArrayEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
+	if !val.IsValid() || val.Kind() != reflect.Array {
+		return ValueEncoderError{Name: "ArrayEncodeValue", Kinds: []reflect.Kind{reflect.Array}, Received: val}
+	}
+
+	// If we have a []primitive.E we want to treat it as a document instead of as an array.
+	if val.Type().Elem() == tE {
+		dw, err := vw.WriteDocument()
+		if err != nil {
+			return err
+		}
+
+		for idx := 0; idx < val.Len(); idx++ {
+			e := val.Index(idx).Interface().(primitive.E)
+			err = encodeElement(ec, dw, e)
+			if err != nil {
+				return err
+			}
+		}
+
+		return dw.WriteDocumentEnd()
+	}
+
+	// If we have a []byte we want to treat it as a binary instead of as an array.
+	if val.Type().Elem() == tByte {
+		var byteSlice []byte
+		for idx := 0; idx < val.Len(); idx++ {
+			byteSlice = append(byteSlice, val.Index(idx).Interface().(byte))
+		}
+		return vw.WriteBinary(byteSlice)
+	}
+
+	aw, err := vw.WriteArray()
+	if err != nil {
+		return err
+	}
+
+	elemType := val.Type().Elem()
+	encoder, err := ec.LookupEncoder(elemType)
+	if err != nil && elemType.Kind() != reflect.Interface {
+		return err
+	}
+
+	for idx := 0; idx < val.Len(); idx++ {
+		currEncoder, currVal, lookupErr := dve.lookupElementEncoder(ec, encoder, val.Index(idx))
+		if lookupErr != nil && lookupErr != errInvalidValue {
+			return lookupErr
+		}
+
+		vw, err := aw.WriteArrayElement()
+		if err != nil {
+			return err
+		}
+
+		if lookupErr == errInvalidValue {
+			err = vw.WriteNull()
+			if err != nil {
+				return err
+			}
+			continue
+		}
+
+		err = currEncoder.EncodeValue(ec, vw, currVal)
+		if err != nil {
+			return err
+		}
+	}
+	return aw.WriteArrayEnd()
+}
+
+// SliceEncodeValue is the ValueEncoderFunc for slice types.
+//
+// Deprecated: SliceEncodeValue is not registered by default. Use SliceCodec.EncodeValue instead.
+func (dve DefaultValueEncoders) SliceEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
+	if !val.IsValid() || val.Kind() != reflect.Slice {
+		return ValueEncoderError{Name: "SliceEncodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val}
+	}
+
+	if val.IsNil() {
+		return vw.WriteNull()
+	}
+
+	// If we have a primitive.D (or a type convertible to one) we want to treat it as a document instead of as an array.
+ if val.Type().ConvertibleTo(tD) { + d := val.Convert(tD).Interface().(primitive.D) + + dw, err := vw.WriteDocument() + if err != nil { + return err + } + + for _, e := range d { + err = encodeElement(ec, dw, e) + if err != nil { + return err + } + } + + return dw.WriteDocumentEnd() + } + + aw, err := vw.WriteArray() + if err != nil { + return err + } + + elemType := val.Type().Elem() + encoder, err := ec.LookupEncoder(elemType) + if err != nil && elemType.Kind() != reflect.Interface { + return err + } + + for idx := 0; idx < val.Len(); idx++ { + currEncoder, currVal, lookupErr := dve.lookupElementEncoder(ec, encoder, val.Index(idx)) + if lookupErr != nil && lookupErr != errInvalidValue { + return lookupErr + } + + vw, err := aw.WriteArrayElement() + if err != nil { + return err + } + + if lookupErr == errInvalidValue { + err = vw.WriteNull() + if err != nil { + return err + } + continue + } + + err = currEncoder.EncodeValue(ec, vw, currVal) + if err != nil { + return err + } + } + return aw.WriteArrayEnd() +} + +func (dve DefaultValueEncoders) lookupElementEncoder(ec EncodeContext, origEncoder ValueEncoder, currVal reflect.Value) (ValueEncoder, reflect.Value, error) { + if origEncoder != nil || (currVal.Kind() != reflect.Interface) { + return origEncoder, currVal, nil + } + currVal = currVal.Elem() + if !currVal.IsValid() { + return nil, currVal, errInvalidValue + } + currEncoder, err := ec.LookupEncoder(currVal.Type()) + + return currEncoder, currVal, err +} + +// EmptyInterfaceEncodeValue is the ValueEncoderFunc for interface{}. +// +// Deprecated: EmptyInterfaceEncodeValue is not registered by default. Use EmptyInterfaceCodec.EncodeValue instead. +func (dve DefaultValueEncoders) EmptyInterfaceEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { + if !val.IsValid() || val.Type() != tEmpty { + return ValueEncoderError{Name: "EmptyInterfaceEncodeValue", Types: []reflect.Type{tEmpty}, Received: val} + } + + if val.IsNil() { + return vw.WriteNull() + } + encoder, err := ec.LookupEncoder(val.Elem().Type()) + if err != nil { + return err + } + + return encoder.EncodeValue(ec, vw, val.Elem()) +} + +// ValueMarshalerEncodeValue is the ValueEncoderFunc for ValueMarshaler implementations. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. 
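+//
+// A minimal sketch of a type this encoder accepts (hypothetical; only the
+// method set matters):
+//
+//	type celsius float64
+//
+//	func (c celsius) MarshalBSONValue() (bsontype.Type, []byte, error) {
+//		return bson.MarshalValue(float64(c))
+//	}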
+func (dve DefaultValueEncoders) ValueMarshalerEncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { + // Either val or a pointer to val must implement ValueMarshaler + switch { + case !val.IsValid(): + return ValueEncoderError{Name: "ValueMarshalerEncodeValue", Types: []reflect.Type{tValueMarshaler}, Received: val} + case val.Type().Implements(tValueMarshaler): + // If ValueMarshaler is implemented on a concrete type, make sure that val isn't a nil pointer + if isImplementationNil(val, tValueMarshaler) { + return vw.WriteNull() + } + case reflect.PtrTo(val.Type()).Implements(tValueMarshaler) && val.CanAddr(): + val = val.Addr() + default: + return ValueEncoderError{Name: "ValueMarshalerEncodeValue", Types: []reflect.Type{tValueMarshaler}, Received: val} + } + + fn := val.Convert(tValueMarshaler).MethodByName("MarshalBSONValue") + returns := fn.Call(nil) + if !returns[2].IsNil() { + return returns[2].Interface().(error) + } + t, data := returns[0].Interface().(bsontype.Type), returns[1].Interface().([]byte) + return bsonrw.Copier{}.CopyValueFromBytes(vw, t, data) +} + +// MarshalerEncodeValue is the ValueEncoderFunc for Marshaler implementations. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. +func (dve DefaultValueEncoders) MarshalerEncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { + // Either val or a pointer to val must implement Marshaler + switch { + case !val.IsValid(): + return ValueEncoderError{Name: "MarshalerEncodeValue", Types: []reflect.Type{tMarshaler}, Received: val} + case val.Type().Implements(tMarshaler): + // If Marshaler is implemented on a concrete type, make sure that val isn't a nil pointer + if isImplementationNil(val, tMarshaler) { + return vw.WriteNull() + } + case reflect.PtrTo(val.Type()).Implements(tMarshaler) && val.CanAddr(): + val = val.Addr() + default: + return ValueEncoderError{Name: "MarshalerEncodeValue", Types: []reflect.Type{tMarshaler}, Received: val} + } + + fn := val.Convert(tMarshaler).MethodByName("MarshalBSON") + returns := fn.Call(nil) + if !returns[1].IsNil() { + return returns[1].Interface().(error) + } + data := returns[0].Interface().([]byte) + return bsonrw.Copier{}.CopyValueFromBytes(vw, bsontype.EmbeddedDocument, data) +} + +// ProxyEncodeValue is the ValueEncoderFunc for Proxy implementations. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. 
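+//
+// A minimal sketch of a Proxy implementation (hypothetical); the value
+// returned by ProxyBSON is encoded in place of the wrapper itself:
+//
+//	type redacted struct{ secret string }
+//
+//	func (r redacted) ProxyBSON() (interface{}, error) { return "<redacted>", nil }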
+func (dve DefaultValueEncoders) ProxyEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { + // Either val or a pointer to val must implement Proxy + switch { + case !val.IsValid(): + return ValueEncoderError{Name: "ProxyEncodeValue", Types: []reflect.Type{tProxy}, Received: val} + case val.Type().Implements(tProxy): + // If Proxy is implemented on a concrete type, make sure that val isn't a nil pointer + if isImplementationNil(val, tProxy) { + return vw.WriteNull() + } + case reflect.PtrTo(val.Type()).Implements(tProxy) && val.CanAddr(): + val = val.Addr() + default: + return ValueEncoderError{Name: "ProxyEncodeValue", Types: []reflect.Type{tProxy}, Received: val} + } + + fn := val.Convert(tProxy).MethodByName("ProxyBSON") + returns := fn.Call(nil) + if !returns[1].IsNil() { + return returns[1].Interface().(error) + } + data := returns[0] + var encoder ValueEncoder + var err error + if data.Elem().IsValid() { + encoder, err = ec.LookupEncoder(data.Elem().Type()) + } else { + encoder, err = ec.LookupEncoder(nil) + } + if err != nil { + return err + } + return encoder.EncodeValue(ec, vw, data.Elem()) +} + +// JavaScriptEncodeValue is the ValueEncoderFunc for the primitive.JavaScript type. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. +func (DefaultValueEncoders) JavaScriptEncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { + if !val.IsValid() || val.Type() != tJavaScript { + return ValueEncoderError{Name: "JavaScriptEncodeValue", Types: []reflect.Type{tJavaScript}, Received: val} + } + + return vw.WriteJavascript(val.String()) +} + +// SymbolEncodeValue is the ValueEncoderFunc for the primitive.Symbol type. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. +func (DefaultValueEncoders) SymbolEncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { + if !val.IsValid() || val.Type() != tSymbol { + return ValueEncoderError{Name: "SymbolEncodeValue", Types: []reflect.Type{tSymbol}, Received: val} + } + + return vw.WriteSymbol(val.String()) +} + +// BinaryEncodeValue is the ValueEncoderFunc for Binary. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. +func (DefaultValueEncoders) BinaryEncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { + if !val.IsValid() || val.Type() != tBinary { + return ValueEncoderError{Name: "BinaryEncodeValue", Types: []reflect.Type{tBinary}, Received: val} + } + b := val.Interface().(primitive.Binary) + + return vw.WriteBinaryWithSubtype(b.Data, b.Subtype) +} + +// UndefinedEncodeValue is the ValueEncoderFunc for Undefined. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. +func (DefaultValueEncoders) UndefinedEncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { + if !val.IsValid() || val.Type() != tUndefined { + return ValueEncoderError{Name: "UndefinedEncodeValue", Types: []reflect.Type{tUndefined}, Received: val} + } + + return vw.WriteUndefined() +} + +// DateTimeEncodeValue is the ValueEncoderFunc for DateTime. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. 
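+//
+// For illustration, primitive.DateTime counts milliseconds since the Unix
+// epoch, so a value is typically built from a time.Time first (sketch):
+//
+//	dt := primitive.NewDateTimeFromTime(time.Unix(0, 0)) // dt == 0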
+func (DefaultValueEncoders) DateTimeEncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { + if !val.IsValid() || val.Type() != tDateTime { + return ValueEncoderError{Name: "DateTimeEncodeValue", Types: []reflect.Type{tDateTime}, Received: val} + } + + return vw.WriteDateTime(val.Int()) +} + +// NullEncodeValue is the ValueEncoderFunc for Null. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. +func (DefaultValueEncoders) NullEncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { + if !val.IsValid() || val.Type() != tNull { + return ValueEncoderError{Name: "NullEncodeValue", Types: []reflect.Type{tNull}, Received: val} + } + + return vw.WriteNull() +} + +// RegexEncodeValue is the ValueEncoderFunc for Regex. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. +func (DefaultValueEncoders) RegexEncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { + if !val.IsValid() || val.Type() != tRegex { + return ValueEncoderError{Name: "RegexEncodeValue", Types: []reflect.Type{tRegex}, Received: val} + } + + regex := val.Interface().(primitive.Regex) + + return vw.WriteRegex(regex.Pattern, regex.Options) +} + +// DBPointerEncodeValue is the ValueEncoderFunc for DBPointer. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. +func (DefaultValueEncoders) DBPointerEncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { + if !val.IsValid() || val.Type() != tDBPointer { + return ValueEncoderError{Name: "DBPointerEncodeValue", Types: []reflect.Type{tDBPointer}, Received: val} + } + + dbp := val.Interface().(primitive.DBPointer) + + return vw.WriteDBPointer(dbp.DB, dbp.Pointer) +} + +// TimestampEncodeValue is the ValueEncoderFunc for Timestamp. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. +func (DefaultValueEncoders) TimestampEncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { + if !val.IsValid() || val.Type() != tTimestamp { + return ValueEncoderError{Name: "TimestampEncodeValue", Types: []reflect.Type{tTimestamp}, Received: val} + } + + ts := val.Interface().(primitive.Timestamp) + + return vw.WriteTimestamp(ts.T, ts.I) +} + +// MinKeyEncodeValue is the ValueEncoderFunc for MinKey. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. +func (DefaultValueEncoders) MinKeyEncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { + if !val.IsValid() || val.Type() != tMinKey { + return ValueEncoderError{Name: "MinKeyEncodeValue", Types: []reflect.Type{tMinKey}, Received: val} + } + + return vw.WriteMinKey() +} + +// MaxKeyEncodeValue is the ValueEncoderFunc for MaxKey. +// +// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default +// value encoders registered. 
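+//
+// For illustration, MinKey and MaxKey sort below and above all other BSON
+// values, so they usually appear as query range bounds (sketch, field name
+// hypothetical):
+//
+//	filter := bson.D{{Key: "score", Value: bson.D{{Key: "$lt", Value: primitive.MaxKey{}}}}}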
+func (DefaultValueEncoders) MaxKeyEncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
+	if !val.IsValid() || val.Type() != tMaxKey {
+		return ValueEncoderError{Name: "MaxKeyEncodeValue", Types: []reflect.Type{tMaxKey}, Received: val}
+	}
+
+	return vw.WriteMaxKey()
+}
+
+// CoreDocumentEncodeValue is the ValueEncoderFunc for bsoncore.Document.
+//
+// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default
+// value encoders registered.
+func (DefaultValueEncoders) CoreDocumentEncodeValue(_ EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
+	if !val.IsValid() || val.Type() != tCoreDocument {
+		return ValueEncoderError{Name: "CoreDocumentEncodeValue", Types: []reflect.Type{tCoreDocument}, Received: val}
+	}
+
+	cdoc := val.Interface().(bsoncore.Document)
+
+	return bsonrw.Copier{}.CopyDocumentFromBytes(vw, cdoc)
+}
+
+// CodeWithScopeEncodeValue is the ValueEncoderFunc for CodeWithScope.
+//
+// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default
+// value encoders registered.
+func (dve DefaultValueEncoders) CodeWithScopeEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
+	if !val.IsValid() || val.Type() != tCodeWithScope {
+		return ValueEncoderError{Name: "CodeWithScopeEncodeValue", Types: []reflect.Type{tCodeWithScope}, Received: val}
+	}
+
+	cws := val.Interface().(primitive.CodeWithScope)
+
+	dw, err := vw.WriteCodeWithScope(string(cws.Code))
+	if err != nil {
+		return err
+	}
+
+	sw := sliceWriterPool.Get().(*bsonrw.SliceWriter)
+	defer sliceWriterPool.Put(sw)
+	*sw = (*sw)[:0]
+
+	scopeVW := bvwPool.Get(sw)
+	defer bvwPool.Put(scopeVW)
+
+	encoder, err := ec.LookupEncoder(reflect.TypeOf(cws.Scope))
+	if err != nil {
+		return err
+	}
+
+	err = encoder.EncodeValue(ec, scopeVW, reflect.ValueOf(cws.Scope))
+	if err != nil {
+		return err
+	}
+
+	err = bsonrw.Copier{}.CopyBytesToDocumentWriter(dw, *sw)
+	if err != nil {
+		return err
+	}
+	return dw.WriteDocumentEnd()
+}
+
+// isImplementationNil reports whether val is a nil pointer and inter is implemented on a concrete type.
+func isImplementationNil(val reflect.Value, inter reflect.Type) bool {
+	vt := val.Type()
+	for vt.Kind() == reflect.Ptr {
+		vt = vt.Elem()
+	}
+	return vt.Implements(inter) && val.Kind() == reflect.Ptr && val.IsNil()
+}
diff --git a/vendor/go.mongodb.org/mongo-driver/bson/bsontype/bsontype.go b/vendor/go.mongodb.org/mongo-driver/bson/bsontype/bsontype.go
new file mode 100644
index 0000000000..f38c263a4c
--- /dev/null
+++ b/vendor/go.mongodb.org/mongo-driver/bson/bsontype/bsontype.go
@@ -0,0 +1,104 @@
+// Copyright (C) MongoDB, Inc. 2017-present.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License. You may obtain
+// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
+
+// Package bsontype is a utility package that contains types for each BSON type and a
+// stringifier for the Type to enable easier debugging when working with BSON.
+package bsontype // import "go.mongodb.org/mongo-driver/bson/bsontype"
+
+// BSON element types as described in https://bsonspec.org/spec.html.
+//
+// Deprecated: Use bson.Type* constants instead.
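+//
+// For illustration, this byte is the first byte of every element in a BSON
+// document; in {"a": "b"} the element "a" starts with 0x02 (String). A sketch
+// of inspecting a type byte:
+//
+//	var t Type = 0x02
+//	_ = t.String() // "string"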
+const (
+	Double           Type = 0x01
+	String           Type = 0x02
+	EmbeddedDocument Type = 0x03
+	Array            Type = 0x04
+	Binary           Type = 0x05
+	Undefined        Type = 0x06
+	ObjectID         Type = 0x07
+	Boolean          Type = 0x08
+	DateTime         Type = 0x09
+	Null             Type = 0x0A
+	Regex            Type = 0x0B
+	DBPointer        Type = 0x0C
+	JavaScript       Type = 0x0D
+	Symbol           Type = 0x0E
+	CodeWithScope    Type = 0x0F
+	Int32            Type = 0x10
+	Timestamp        Type = 0x11
+	Int64            Type = 0x12
+	Decimal128       Type = 0x13
+	MinKey           Type = 0xFF
+	MaxKey           Type = 0x7F
+)
+
+// BSON binary element subtypes as described in https://bsonspec.org/spec.html.
+//
+// Deprecated: Use the bson.TypeBinary* constants instead.
+const (
+	BinaryGeneric     byte = 0x00
+	BinaryFunction    byte = 0x01
+	BinaryBinaryOld   byte = 0x02
+	BinaryUUIDOld     byte = 0x03
+	BinaryUUID        byte = 0x04
+	BinaryMD5         byte = 0x05
+	BinaryEncrypted   byte = 0x06
+	BinaryColumn      byte = 0x07
+	BinaryUserDefined byte = 0x80
+)
+
+// Type represents a BSON type.
+type Type byte
+
+// String returns the string representation of the BSON type's name.
+func (bt Type) String() string {
+	switch bt {
+	case '\x01':
+		return "double"
+	case '\x02':
+		return "string"
+	case '\x03':
+		return "embedded document"
+	case '\x04':
+		return "array"
+	case '\x05':
+		return "binary"
+	case '\x06':
+		return "undefined"
+	case '\x07':
+		return "objectID"
+	case '\x08':
+		return "boolean"
+	case '\x09':
+		return "UTC datetime"
+	case '\x0A':
+		return "null"
+	case '\x0B':
+		return "regex"
+	case '\x0C':
+		return "dbPointer"
+	case '\x0D':
+		return "javascript"
+	case '\x0E':
+		return "symbol"
+	case '\x0F':
+		return "code with scope"
+	case '\x10':
+		return "32-bit integer"
+	case '\x11':
+		return "timestamp"
+	case '\x12':
+		return "64-bit integer"
+	case '\x13':
+		return "128-bit decimal"
+	case '\xFF':
+		return "min key"
+	case '\x7F':
+		return "max key"
+	default:
+		return "invalid"
+	}
+}
diff --git a/vendor/go.mongodb.org/mongo-driver/bson/primitive_codecs.go b/vendor/go.mongodb.org/mongo-driver/bson/primitive_codecs.go
new file mode 100644
index 0000000000..6b9602589c
--- /dev/null
+++ b/vendor/go.mongodb.org/mongo-driver/bson/primitive_codecs.go
@@ -0,0 +1,110 @@
+// Copyright (C) MongoDB, Inc. 2017-present.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License. You may obtain
+// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
+
+package bson
+
+import (
+	"errors"
+	"reflect"
+
+	"go.mongodb.org/mongo-driver/bson/bsoncodec"
+	"go.mongodb.org/mongo-driver/bson/bsonrw"
+)
+
+var tRawValue = reflect.TypeOf(RawValue{})
+var tRaw = reflect.TypeOf(Raw(nil))
+
+var primitiveCodecs PrimitiveCodecs
+
+// PrimitiveCodecs is a namespace for all of the default bsoncodec.Codecs for the primitive types
+// defined in this package.
+//
+// Deprecated: Use bson.NewRegistry to get a registry with all primitive encoders and decoders
+// registered.
+type PrimitiveCodecs struct{}
+
+// RegisterPrimitiveCodecs will register the encode and decode methods attached to PrimitiveCodecs
+// with the provided RegistryBuilder. If rb is nil, this method panics.
+//
+// Deprecated: Use bson.NewRegistry to get a registry with all primitive encoders and decoders
+// registered.
+func (pc PrimitiveCodecs) RegisterPrimitiveCodecs(rb *bsoncodec.RegistryBuilder) {
+	if rb == nil {
+		panic(errors.New("argument to RegisterPrimitiveCodecs must not be nil"))
+	}
+
+	rb.
+		RegisterTypeEncoder(tRawValue, bsoncodec.ValueEncoderFunc(pc.RawValueEncodeValue)).
+		RegisterTypeEncoder(tRaw, bsoncodec.ValueEncoderFunc(pc.RawEncodeValue)).
+		RegisterTypeDecoder(tRawValue, bsoncodec.ValueDecoderFunc(pc.RawValueDecodeValue)).
+		RegisterTypeDecoder(tRaw, bsoncodec.ValueDecoderFunc(pc.RawDecodeValue))
+}
+
+// RawValueEncodeValue is the ValueEncoderFunc for RawValue.
+//
+// Deprecated: Use bson.NewRegistry to get a registry with all primitive encoders and decoders
+// registered.
+func (PrimitiveCodecs) RawValueEncodeValue(_ bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
+	if !val.IsValid() || val.Type() != tRawValue {
+		return bsoncodec.ValueEncoderError{Name: "RawValueEncodeValue", Types: []reflect.Type{tRawValue}, Received: val}
+	}
+
+	rawvalue := val.Interface().(RawValue)
+
+	return bsonrw.Copier{}.CopyValueFromBytes(vw, rawvalue.Type, rawvalue.Value)
+}
+
+// RawValueDecodeValue is the ValueDecoderFunc for RawValue.
+//
+// Deprecated: Use bson.NewRegistry to get a registry with all primitive encoders and decoders
+// registered.
+func (PrimitiveCodecs) RawValueDecodeValue(_ bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
+	if !val.CanSet() || val.Type() != tRawValue {
+		return bsoncodec.ValueDecoderError{Name: "RawValueDecodeValue", Types: []reflect.Type{tRawValue}, Received: val}
+	}
+
+	t, value, err := bsonrw.Copier{}.CopyValueToBytes(vr)
+	if err != nil {
+		return err
+	}
+
+	val.Set(reflect.ValueOf(RawValue{Type: t, Value: value}))
+	return nil
+}
+
+// RawEncodeValue is the ValueEncoderFunc for Raw.
+//
+// Deprecated: Use bson.NewRegistry to get a registry with all primitive encoders and decoders
+// registered.
+func (PrimitiveCodecs) RawEncodeValue(_ bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
+	if !val.IsValid() || val.Type() != tRaw {
+		return bsoncodec.ValueEncoderError{Name: "RawEncodeValue", Types: []reflect.Type{tRaw}, Received: val}
+	}
+
+	rdr := val.Interface().(Raw)
+
+	return bsonrw.Copier{}.CopyDocumentFromBytes(vw, rdr)
+}
+
+// RawDecodeValue is the ValueDecoderFunc for Raw.
+//
+// Deprecated: Use bson.NewRegistry to get a registry with all primitive encoders and decoders
+// registered.
+func (PrimitiveCodecs) RawDecodeValue(_ bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
+	if !val.CanSet() || val.Type() != tRaw {
+		return bsoncodec.ValueDecoderError{Name: "RawDecodeValue", Types: []reflect.Type{tRaw}, Received: val}
+	}
+
+	if val.IsNil() {
+		val.Set(reflect.MakeSlice(val.Type(), 0, 0))
+	}
+
+	val.SetLen(0)
+
+	rdr, err := bsonrw.Copier{}.AppendDocumentBytes(val.Interface().(Raw), vr)
+	val.Set(reflect.ValueOf(rdr))
+	return err
+}
diff --git a/vendor/go.mongodb.org/mongo-driver/bson/raw_value.go b/vendor/go.mongodb.org/mongo-driver/bson/raw_value.go
new file mode 100644
index 0000000000..6627294c4d
--- /dev/null
+++ b/vendor/go.mongodb.org/mongo-driver/bson/raw_value.go
@@ -0,0 +1,314 @@
+// Copyright (C) MongoDB, Inc. 2017-present.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License. You may obtain
+// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
+
+package bson
+
+import (
+	"bytes"
+	"errors"
+	"fmt"
+	"reflect"
+	"time"
+
+	"go.mongodb.org/mongo-driver/bson/bsoncodec"
+	"go.mongodb.org/mongo-driver/bson/bsonrw"
+	"go.mongodb.org/mongo-driver/bson/bsontype"
+	"go.mongodb.org/mongo-driver/bson/primitive"
+	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
+)
+
+// ErrNilContext is returned when the provided DecodeContext is nil.
+var ErrNilContext = errors.New("DecodeContext cannot be nil")
+
+// ErrNilRegistry is returned when the provided registry is nil.
+var ErrNilRegistry = errors.New("Registry cannot be nil")
+
+// RawValue is a raw encoded BSON value. It can be used to delay BSON value decoding or to
+// precompute a BSON encoded value. Type is the BSON type of the value and Value is the raw
+// encoded BSON value.
+//
+// A RawValue must be an individual BSON value. Use the Raw type for full BSON documents.
+type RawValue struct {
+	Type  bsontype.Type
+	Value []byte
+
+	r *bsoncodec.Registry
+}
+
+// Unmarshal deserializes BSON into the provided val. If RawValue cannot be unmarshaled into val,
+// an error is returned. This method will use the registry used to create the RawValue if the
+// RawValue was created from partial BSON processing, or it will use the default registry
+// otherwise. Users wishing to specify the registry to use should use UnmarshalWithRegistry.
+func (rv RawValue) Unmarshal(val interface{}) error {
+	reg := rv.r
+	if reg == nil {
+		reg = DefaultRegistry
+	}
+	return rv.UnmarshalWithRegistry(reg, val)
+}
+
+// Equal compares rv and rv2 and returns true if they are equal.
+func (rv RawValue) Equal(rv2 RawValue) bool {
+	if rv.Type != rv2.Type {
+		return false
+	}
+
+	if !bytes.Equal(rv.Value, rv2.Value) {
+		return false
+	}
+
+	return true
+}
+
+// UnmarshalWithRegistry performs the same unmarshalling as Unmarshal but uses the provided registry
+// instead of the one attached or the default registry.
+func (rv RawValue) UnmarshalWithRegistry(r *bsoncodec.Registry, val interface{}) error {
+	if r == nil {
+		return ErrNilRegistry
+	}
+
+	vr := bsonrw.NewBSONValueReader(rv.Type, rv.Value)
+	rval := reflect.ValueOf(val)
+	if rval.Kind() != reflect.Ptr {
+		return fmt.Errorf("argument to Unmarshal* must be a pointer to a type, but got %v", rval)
+	}
+	rval = rval.Elem()
+	dec, err := r.LookupDecoder(rval.Type())
+	if err != nil {
+		return err
+	}
+	return dec.DecodeValue(bsoncodec.DecodeContext{Registry: r}, vr, rval)
+}
+
+// UnmarshalWithContext performs the same unmarshalling as Unmarshal but uses the provided
+// DecodeContext instead of the one attached or the default registry.
+func (rv RawValue) UnmarshalWithContext(dc *bsoncodec.DecodeContext, val interface{}) error {
+	if dc == nil {
+		return ErrNilContext
+	}
+
+	vr := bsonrw.NewBSONValueReader(rv.Type, rv.Value)
+	rval := reflect.ValueOf(val)
+	if rval.Kind() != reflect.Ptr {
+		return fmt.Errorf("argument to Unmarshal* must be a pointer to a type, but got %v", rval)
+	}
+	rval = rval.Elem()
+	dec, err := dc.LookupDecoder(rval.Type())
+	if err != nil {
+		return err
+	}
+	return dec.DecodeValue(*dc, vr, rval)
+}
+
+func convertFromCoreValue(v bsoncore.Value) RawValue { return RawValue{Type: v.Type, Value: v.Data} }
+func convertToCoreValue(v RawValue) bsoncore.Value {
+	return bsoncore.Value{Type: v.Type, Data: v.Value}
+}
+
+// Validate ensures the value is a valid BSON value.
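+//
+// A minimal usage sketch (bytes hypothetical; 0x2a is 42 encoded as a
+// little-endian int32):
+//
+//	rv := RawValue{Type: bsontype.Int32, Value: []byte{0x2a, 0x00, 0x00, 0x00}}
+//	err := rv.Validate() // nil for a well-formed value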
+func (rv RawValue) Validate() error { return convertToCoreValue(rv).Validate() }
+
+// IsNumber returns true if the type of the value is a numeric BSON type.
+func (rv RawValue) IsNumber() bool { return convertToCoreValue(rv).IsNumber() }
+
+// String implements the fmt.Stringer interface. This method will return values in extended JSON
+// format. If the value is not valid, this returns an empty string.
+func (rv RawValue) String() string { return convertToCoreValue(rv).String() }
+
+// DebugString outputs a human readable version of the value. It will attempt to stringify the
+// valid components of the value even if the entire value is not valid.
+func (rv RawValue) DebugString() string { return convertToCoreValue(rv).DebugString() }
+
+// Double returns the float64 value for this element.
+// It panics if the value's BSON type is not bsontype.Double.
+func (rv RawValue) Double() float64 { return convertToCoreValue(rv).Double() }
+
+// DoubleOK is the same as Double, but returns a boolean instead of panicking.
+func (rv RawValue) DoubleOK() (float64, bool) { return convertToCoreValue(rv).DoubleOK() }
+
+// StringValue returns the string value for this element.
+// It panics if the value's BSON type is not bsontype.String.
+//
+// NOTE: This method is called StringValue to avoid a collision with the String method which
+// implements the fmt.Stringer interface.
+func (rv RawValue) StringValue() string { return convertToCoreValue(rv).StringValue() }
+
+// StringValueOK is the same as StringValue, but returns a boolean instead of
+// panicking.
+func (rv RawValue) StringValueOK() (string, bool) { return convertToCoreValue(rv).StringValueOK() }
+
+// Document returns the BSON document the Value represents as a Document. It panics if the
+// value is a BSON type other than document.
+func (rv RawValue) Document() Raw { return Raw(convertToCoreValue(rv).Document()) }
+
+// DocumentOK is the same as Document, except it returns a boolean
+// instead of panicking.
+func (rv RawValue) DocumentOK() (Raw, bool) {
+	doc, ok := convertToCoreValue(rv).DocumentOK()
+	return Raw(doc), ok
+}
+
+// Array returns the BSON array the Value represents as an Array. It panics if the
+// value is a BSON type other than array.
+func (rv RawValue) Array() Raw { return Raw(convertToCoreValue(rv).Array()) }
+
+// ArrayOK is the same as Array, except it returns a boolean instead
+// of panicking.
+func (rv RawValue) ArrayOK() (Raw, bool) {
+	doc, ok := convertToCoreValue(rv).ArrayOK()
+	return Raw(doc), ok
+}
+
+// Binary returns the BSON binary value the Value represents. It panics if the value is a BSON type
+// other than binary.
+func (rv RawValue) Binary() (subtype byte, data []byte) { return convertToCoreValue(rv).Binary() }
+
+// BinaryOK is the same as Binary, except it returns a boolean instead of
+// panicking.
+func (rv RawValue) BinaryOK() (subtype byte, data []byte, ok bool) {
+	return convertToCoreValue(rv).BinaryOK()
+}
+
+// ObjectID returns the BSON objectid value the Value represents. It panics if the value is a BSON
+// type other than objectid.
+func (rv RawValue) ObjectID() primitive.ObjectID { return convertToCoreValue(rv).ObjectID() }
+
+// ObjectIDOK is the same as ObjectID, except it returns a boolean instead of
+// panicking.
+func (rv RawValue) ObjectIDOK() (primitive.ObjectID, bool) {
+	return convertToCoreValue(rv).ObjectIDOK()
+}
+
+// Boolean returns the boolean value the Value represents. It panics if the
+// value is a BSON type other than boolean.
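+//
+// When the type is not known in advance, the *OK variant avoids the panic
+// (sketch):
+//
+//	if b, ok := rv.BooleanOK(); ok {
+//		_ = b // only meaningful when ok is true
+//	}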
+func (rv RawValue) Boolean() bool { return convertToCoreValue(rv).Boolean() }
+
+// BooleanOK is the same as Boolean, except it returns a boolean instead of
+// panicking.
+func (rv RawValue) BooleanOK() (bool, bool) { return convertToCoreValue(rv).BooleanOK() }
+
+// DateTime returns the BSON datetime value the Value represents as a
+// unix timestamp. It panics if the value is a BSON type other than datetime.
+func (rv RawValue) DateTime() int64 { return convertToCoreValue(rv).DateTime() }
+
+// DateTimeOK is the same as DateTime, except it returns a boolean instead of
+// panicking.
+func (rv RawValue) DateTimeOK() (int64, bool) { return convertToCoreValue(rv).DateTimeOK() }
+
+// Time returns the BSON datetime value the Value represents. It panics if the value is a BSON
+// type other than datetime.
+func (rv RawValue) Time() time.Time { return convertToCoreValue(rv).Time() }
+
+// TimeOK is the same as Time, except it returns a boolean instead of
+// panicking.
+func (rv RawValue) TimeOK() (time.Time, bool) { return convertToCoreValue(rv).TimeOK() }
+
+// Regex returns the BSON regex value the Value represents. It panics if the value is a BSON
+// type other than regex.
+func (rv RawValue) Regex() (pattern, options string) { return convertToCoreValue(rv).Regex() }
+
+// RegexOK is the same as Regex, except it returns a boolean instead of
+// panicking.
+func (rv RawValue) RegexOK() (pattern, options string, ok bool) {
+	return convertToCoreValue(rv).RegexOK()
+}
+
+// DBPointer returns the BSON dbpointer value the Value represents. It panics if the value is a BSON
+// type other than DBPointer.
+func (rv RawValue) DBPointer() (string, primitive.ObjectID) {
+	return convertToCoreValue(rv).DBPointer()
+}
+
+// DBPointerOK is the same as DBPointer, except that it returns a boolean
+// instead of panicking.
+func (rv RawValue) DBPointerOK() (string, primitive.ObjectID, bool) {
+	return convertToCoreValue(rv).DBPointerOK()
+}
+
+// JavaScript returns the BSON JavaScript code value the Value represents. It panics if the value is
+// a BSON type other than JavaScript code.
+func (rv RawValue) JavaScript() string { return convertToCoreValue(rv).JavaScript() }
+
+// JavaScriptOK is the same as JavaScript, except that it returns a boolean
+// instead of panicking.
+func (rv RawValue) JavaScriptOK() (string, bool) { return convertToCoreValue(rv).JavaScriptOK() }
+
+// Symbol returns the BSON symbol value the Value represents. It panics if the value is a BSON
+// type other than symbol.
+func (rv RawValue) Symbol() string { return convertToCoreValue(rv).Symbol() }
+
+// SymbolOK is the same as Symbol, except that it returns a boolean
+// instead of panicking.
+func (rv RawValue) SymbolOK() (string, bool) { return convertToCoreValue(rv).SymbolOK() }
+
+// CodeWithScope returns the BSON JavaScript code with scope the Value represents.
+// It panics if the value is a BSON type other than JavaScript code with scope.
+func (rv RawValue) CodeWithScope() (string, Raw) {
+	code, scope := convertToCoreValue(rv).CodeWithScope()
+	return code, Raw(scope)
+}
+
+// CodeWithScopeOK is the same as CodeWithScope, except that it returns a boolean instead of
+// panicking.
+func (rv RawValue) CodeWithScopeOK() (string, Raw, bool) {
+	code, scope, ok := convertToCoreValue(rv).CodeWithScopeOK()
+	return code, Raw(scope), ok
+}
+
+// Int32 returns the int32 the Value represents. It panics if the value is a BSON type other than
+// int32.
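+//
+// A numeric value of unknown width can be read through AsInt64 instead
+// (sketch):
+//
+//	if rv.Type == bsontype.Int32 || rv.Type == bsontype.Int64 {
+//		_ = rv.AsInt64() // widens int32 values to int64
+//	}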
+func (rv RawValue) Int32() int32 { return convertToCoreValue(rv).Int32() }
+
+// Int32OK is the same as Int32, except that it returns a boolean instead of
+// panicking.
+func (rv RawValue) Int32OK() (int32, bool) { return convertToCoreValue(rv).Int32OK() }
+
+// AsInt32 returns a BSON number as an int32. If the BSON type is not a numeric one, this method
+// will panic.
+//
+// Deprecated: Use AsInt64 instead. If an int32 is required, convert the returned value to an int32
+// and perform any required overflow/underflow checking.
+func (rv RawValue) AsInt32() int32 { return convertToCoreValue(rv).AsInt32() }
+
+// AsInt32OK is the same as AsInt32, except that it returns a boolean instead of
+// panicking.
+//
+// Deprecated: Use AsInt64OK instead. If an int32 is required, convert the returned value to an
+// int32 and perform any required overflow/underflow checking.
+func (rv RawValue) AsInt32OK() (int32, bool) { return convertToCoreValue(rv).AsInt32OK() }
+
+// Timestamp returns the BSON timestamp value the Value represents. It panics if the value is a
+// BSON type other than timestamp.
+func (rv RawValue) Timestamp() (t, i uint32) { return convertToCoreValue(rv).Timestamp() }
+
+// TimestampOK is the same as Timestamp, except that it returns a boolean
+// instead of panicking.
+func (rv RawValue) TimestampOK() (t, i uint32, ok bool) { return convertToCoreValue(rv).TimestampOK() }
+
+// Int64 returns the int64 the Value represents. It panics if the value is a BSON type other than
+// int64.
+func (rv RawValue) Int64() int64 { return convertToCoreValue(rv).Int64() }
+
+// Int64OK is the same as Int64, except that it returns a boolean instead of
+// panicking.
+func (rv RawValue) Int64OK() (int64, bool) { return convertToCoreValue(rv).Int64OK() }
+
+// AsInt64 returns a BSON number as an int64. If the BSON type is not a numeric one, this method
+// will panic.
+func (rv RawValue) AsInt64() int64 { return convertToCoreValue(rv).AsInt64() }
+
+// AsInt64OK is the same as AsInt64, except that it returns a boolean instead of
+// panicking.
+func (rv RawValue) AsInt64OK() (int64, bool) { return convertToCoreValue(rv).AsInt64OK() }
+
+// Decimal128 returns the decimal the Value represents. It panics if the value is a BSON type other
+// than decimal.
+func (rv RawValue) Decimal128() primitive.Decimal128 { return convertToCoreValue(rv).Decimal128() }
+
+// Decimal128OK is the same as Decimal128, except that it returns a boolean
+// instead of panicking.
+func (rv RawValue) Decimal128OK() (primitive.Decimal128, bool) {
+	return convertToCoreValue(rv).Decimal128OK()
+}
diff --git a/vendor/go.mongodb.org/mongo-driver/internal/logger/logger.go b/vendor/go.mongodb.org/mongo-driver/internal/logger/logger.go
new file mode 100644
index 0000000000..c4053ea3df
--- /dev/null
+++ b/vendor/go.mongodb.org/mongo-driver/internal/logger/logger.go
@@ -0,0 +1,250 @@
+// Copyright (C) MongoDB, Inc. 2023-present.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License. You may obtain
+// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
+
+package logger
+
+import (
+	"fmt"
+	"os"
+	"strconv"
+	"strings"
+)
+
+// DefaultMaxDocumentLength is the default maximum number of bytes that can be
+// logged for a stringified BSON document.
+const DefaultMaxDocumentLength = 1000
+
+// TruncationSuffix is the trailing ellipsis "..." appended to a message to
+// indicate to the user that truncation occurred. This constant does not count
+// toward the max document length.
+const TruncationSuffix = "..."
+
+const logSinkPathEnvVar = "MONGODB_LOG_PATH"
+const maxDocumentLengthEnvVar = "MONGODB_LOG_MAX_DOCUMENT_LENGTH"
+
+// LogSink represents a logging implementation; this interface should be 1-1
+// with the exported "LogSink" interface in the mongo/options package.
+type LogSink interface {
+	// Info logs a non-error message with the given key/value pairs. The
+	// level argument is provided for optional logging.
+	Info(level int, msg string, keysAndValues ...interface{})
+
+	// Error logs an error, with the given message and key/value pairs.
+	Error(err error, msg string, keysAndValues ...interface{})
+}
+
+// Logger represents the configuration for the internal logger.
+type Logger struct {
+	ComponentLevels   map[Component]Level // Log levels for each component.
+	Sink              LogSink             // LogSink for log printing.
+	MaxDocumentLength uint                // Command truncation width.
+	logFile           *os.File            // File to write logs to.
+}
+
+// New will construct a new logger. If any of the given options are the
+// zero-value of the argument type, then the constructor will attempt to
+// source the data from the environment. If the environment has not been set,
+// then the constructor will use the respective default values.
+func New(sink LogSink, maxDocLen uint, compLevels map[Component]Level) (*Logger, error) {
+	logger := &Logger{
+		ComponentLevels:   selectComponentLevels(compLevels),
+		MaxDocumentLength: selectMaxDocumentLength(maxDocLen),
+	}
+
+	sink, logFile, err := selectLogSink(sink)
+	if err != nil {
+		return nil, err
+	}
+
+	logger.Sink = sink
+	logger.logFile = logFile
+
+	return logger, nil
+}
+
+// Close will close the logger's log file, if it exists.
+func (logger *Logger) Close() error {
+	if logger.logFile != nil {
+		return logger.logFile.Close()
+	}
+
+	return nil
+}
+
+// LevelComponentEnabled will return true if the given LogLevel is enabled for
+// the given LogComponent.
+func (logger *Logger) LevelComponentEnabled(level Level, component Component) bool {
+	return logger.ComponentLevels[component] >= level
+}
+
+// Print will synchronously print the given message to the configured LogSink.
+// If the LogSink is nil, then this method will do nothing. Future work could
+// be done to make this method asynchronous; see buffer management in
+// libraries such as log4j.
+func (logger *Logger) Print(level Level, component Component, msg string, keysAndValues ...interface{}) {
+	// If the level is not enabled for the component, then
+	// skip the message.
+	if !logger.LevelComponentEnabled(level, component) {
+		return
+	}
+
+	// If the sink is nil, then skip the message.
+	if logger.Sink == nil {
+		return
+	}
+
+	logger.Sink.Info(int(level)-DiffToInfo, msg, keysAndValues...)
+}
+
+// Error logs an error, with the given message and key/value pairs.
+// It functions similarly to Print, but may have unique behavior, and should be
+// preferred for logging errors.
+func (logger *Logger) Error(err error, msg string, keysAndValues ...interface{}) {
+	if logger.Sink == nil {
+		return
+	}
+
+	logger.Sink.Error(err, msg, keysAndValues...)
+}
+
+// selectMaxDocumentLength returns the first non-zero max document length, with
+// the user-defined value taking priority over the environment variable. For
+// the environment, the function will attempt to get the value of
+// "MONGODB_LOG_MAX_DOCUMENT_LENGTH" and parse it as an unsigned integer.
If the +// environment variable is not set or is not an unsigned integer, then this +// function will return the default max document length. +func selectMaxDocumentLength(maxDocLen uint) uint { + if maxDocLen != 0 { + return maxDocLen + } + + maxDocLenEnv := os.Getenv(maxDocumentLengthEnvVar) + if maxDocLenEnv != "" { + maxDocLenEnvInt, err := strconv.ParseUint(maxDocLenEnv, 10, 32) + if err == nil { + return uint(maxDocLenEnvInt) + } + } + + return DefaultMaxDocumentLength +} + +const ( + logSinkPathStdout = "stdout" + logSinkPathStderr = "stderr" +) + +// selectLogSink will return the first non-nil LogSink, with the user-defined +// LogSink taking precedence over the environment-defined LogSink. If no LogSink +// is defined, then this function will return a LogSink that writes to stderr. +func selectLogSink(sink LogSink) (LogSink, *os.File, error) { + if sink != nil { + return sink, nil, nil + } + + path := os.Getenv(logSinkPathEnvVar) + lowerPath := strings.ToLower(path) + + if lowerPath == string(logSinkPathStderr) { + return NewIOSink(os.Stderr), nil, nil + } + + if lowerPath == string(logSinkPathStdout) { + return NewIOSink(os.Stdout), nil, nil + } + + if path != "" { + logFile, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0666) + if err != nil { + return nil, nil, fmt.Errorf("unable to open log file: %v", err) + } + + return NewIOSink(logFile), logFile, nil + } + + return NewIOSink(os.Stderr), nil, nil +} + +// selectComponentLevels returns a new map of LogComponents to LogLevels that is +// the result of merging the user-defined data with the environment, with the +// user-defined data taking priority. +func selectComponentLevels(componentLevels map[Component]Level) map[Component]Level { + selected := make(map[Component]Level) + + // Determine if the "MONGODB_LOG_ALL" environment variable is set. + var globalEnvLevel *Level + if all := os.Getenv(mongoDBLogAllEnvVar); all != "" { + level := ParseLevel(all) + globalEnvLevel = &level + } + + for envVar, component := range componentEnvVarMap { + // If the component already has a level, then skip it. + if _, ok := componentLevels[component]; ok { + selected[component] = componentLevels[component] + + continue + } + + // If the "MONGODB_LOG_ALL" environment variable is set, then + // set the level for the component to the value of the + // environment variable. + if globalEnvLevel != nil { + selected[component] = *globalEnvLevel + + continue + } + + // Otherwise, set the level for the component to the value of + // the environment variable. + selected[component] = ParseLevel(os.Getenv(envVar)) + } + + return selected +} + +// truncate will truncate a string to the given width, appending "..." to the +// end of the string if it is truncated. This routine is safe for multi-byte +// characters. +func truncate(str string, width uint) string { + if width == 0 { + return "" + } + + if len(str) <= int(width) { + return str + } + + // Truncate the byte slice of the string to the given width. + newStr := str[:width] + + // Check if the last byte is at the beginning of a multi-byte character. + // If it is, then remove the last byte. + if newStr[len(newStr)-1]&0xC0 == 0xC0 { + return newStr[:len(newStr)-1] + TruncationSuffix + } + + // Check if the last byte is in the middle of a multi-byte character. If + // it is, then step back until we find the beginning of the character. 
+ if newStr[len(newStr)-1]&0xC0 == 0x80 { + for i := len(newStr) - 1; i >= 0; i-- { + if newStr[i]&0xC0 == 0xC0 { + return newStr[:i] + TruncationSuffix + } + } + } + + return newStr + TruncationSuffix +} + +// FormatMessage formats a BSON document for logging. The document is truncated +// to the given width. +func FormatMessage(msg string, width uint) string { + if len(msg) == 0 { + return "{}" + } + + return truncate(msg, width) +} diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/change_stream.go b/vendor/go.mongodb.org/mongo-driver/mongo/change_stream.go new file mode 100644 index 0000000000..76fe86f000 --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/mongo/change_stream.go @@ -0,0 +1,712 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package mongo + +import ( + "context" + "errors" + "fmt" + "reflect" + "strconv" + "time" + + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/bsoncodec" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/internal" + "go.mongodb.org/mongo-driver/mongo/description" + "go.mongodb.org/mongo-driver/mongo/options" + "go.mongodb.org/mongo-driver/mongo/readconcern" + "go.mongodb.org/mongo-driver/mongo/readpref" + "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/operation" + "go.mongodb.org/mongo-driver/x/mongo/driver/session" +) + +var ( + // ErrMissingResumeToken indicates that a change stream notification from the server did not contain a resume token. + ErrMissingResumeToken = errors.New("cannot provide resume functionality when the resume token is missing") + // ErrNilCursor indicates that the underlying cursor for the change stream is nil. + ErrNilCursor = errors.New("cursor is nil") + + minResumableLabelWireVersion int32 = 9 // Wire version at which the server includes the resumable error label + networkErrorLabel = "NetworkError" + resumableErrorLabel = "ResumableChangeStreamError" + errorCursorNotFound int32 = 43 // CursorNotFound error code + + // Allowlist of error codes that are considered resumable. + resumableChangeStreamErrors = map[int32]struct{}{ + 6: {}, // HostUnreachable + 7: {}, // HostNotFound + 89: {}, // NetworkTimeout + 91: {}, // ShutdownInProgress + 189: {}, // PrimarySteppedDown + 262: {}, // ExceededTimeLimit + 9001: {}, // SocketException + 10107: {}, // NotPrimary + 11600: {}, // InterruptedAtShutdown + 11602: {}, // InterruptedDueToReplStateChange + 13435: {}, // NotPrimaryNoSecondaryOK + 13436: {}, // NotPrimaryOrSecondary + 63: {}, // StaleShardVersion + 150: {}, // StaleEpoch + 13388: {}, // StaleConfig + 234: {}, // RetryChangeStream + 133: {}, // FailedToSatisfyReadPreference + } +) + +// ChangeStream is used to iterate over a stream of events. Each event can be decoded into a Go type via the Decode +// method or accessed as raw BSON via the Current field. This type is not goroutine safe and must not be used +// concurrently by multiple goroutines. For more information about change streams, see +// https://www.mongodb.com/docs/manual/changeStreams/. +type ChangeStream struct { + // Current is the BSON bytes of the current event. This property is only valid until the next call to Next or + // TryNext. 
If continued access is required, a copy must be made. + Current bson.Raw + + aggregate *operation.Aggregate + pipelineSlice []bsoncore.Document + pipelineOptions map[string]bsoncore.Value + cursor changeStreamCursor + cursorOptions driver.CursorOptions + batch []bsoncore.Document + resumeToken bson.Raw + err error + sess *session.Client + client *Client + bsonOpts *options.BSONOptions + registry *bsoncodec.Registry + streamType StreamType + options *options.ChangeStreamOptions + selector description.ServerSelector + operationTime *primitive.Timestamp + wireVersion *description.VersionRange +} + +type changeStreamConfig struct { + readConcern *readconcern.ReadConcern + readPreference *readpref.ReadPref + client *Client + bsonOpts *options.BSONOptions + registry *bsoncodec.Registry + streamType StreamType + collectionName string + databaseName string + crypt driver.Crypt +} + +func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline interface{}, + opts ...*options.ChangeStreamOptions) (*ChangeStream, error) { + if ctx == nil { + ctx = context.Background() + } + + cs := &ChangeStream{ + client: config.client, + bsonOpts: config.bsonOpts, + registry: config.registry, + streamType: config.streamType, + options: options.MergeChangeStreamOptions(opts...), + selector: description.CompositeSelector([]description.ServerSelector{ + description.ReadPrefSelector(config.readPreference), + description.LatencySelector(config.client.localThreshold), + }), + cursorOptions: config.client.createBaseCursorOptions(), + } + + cs.sess = sessionFromContext(ctx) + if cs.sess == nil && cs.client.sessionPool != nil { + cs.sess = session.NewImplicitClientSession(cs.client.sessionPool, cs.client.id) + } + if cs.err = cs.client.validSession(cs.sess); cs.err != nil { + closeImplicitSession(cs.sess) + return nil, cs.Err() + } + + cs.aggregate = operation.NewAggregate(nil). + ReadPreference(config.readPreference).ReadConcern(config.readConcern). + Deployment(cs.client.deployment).ClusterClock(cs.client.clock). + CommandMonitor(cs.client.monitor).Session(cs.sess).ServerSelector(cs.selector).Retry(driver.RetryNone). + ServerAPI(cs.client.serverAPI).Crypt(config.crypt).Timeout(cs.client.timeout) + + if cs.options.Collation != nil { + cs.aggregate.Collation(bsoncore.Document(cs.options.Collation.ToDocument())) + } + if comment := cs.options.Comment; comment != nil { + cs.aggregate.Comment(*comment) + + commentVal, err := marshalValue(comment, cs.bsonOpts, cs.registry) + if err != nil { + return nil, err + } + cs.cursorOptions.Comment = commentVal + } + if cs.options.BatchSize != nil { + cs.aggregate.BatchSize(*cs.options.BatchSize) + cs.cursorOptions.BatchSize = *cs.options.BatchSize + } + if cs.options.MaxAwaitTime != nil { + cs.cursorOptions.MaxTimeMS = int64(*cs.options.MaxAwaitTime / time.Millisecond) + } + if cs.options.Custom != nil { + // Marshal all custom options before passing to the initial aggregate. Return + // any errors from Marshaling. + customOptions := make(map[string]bsoncore.Value) + for optionName, optionValue := range cs.options.Custom { + bsonType, bsonData, err := bson.MarshalValueWithRegistry(cs.registry, optionValue) + if err != nil { + cs.err = err + closeImplicitSession(cs.sess) + return nil, cs.Err() + } + optionValueBSON := bsoncore.Value{Type: bsonType, Data: bsonData} + customOptions[optionName] = optionValueBSON + } + cs.aggregate.CustomOptions(customOptions) + } + if cs.options.CustomPipeline != nil { + // Marshal all custom pipeline options before building pipeline slice. 
Return + // any errors from Marshaling. + cs.pipelineOptions = make(map[string]bsoncore.Value) + for optionName, optionValue := range cs.options.CustomPipeline { + bsonType, bsonData, err := bson.MarshalValueWithRegistry(cs.registry, optionValue) + if err != nil { + cs.err = err + closeImplicitSession(cs.sess) + return nil, cs.Err() + } + optionValueBSON := bsoncore.Value{Type: bsonType, Data: bsonData} + cs.pipelineOptions[optionName] = optionValueBSON + } + } + + switch cs.streamType { + case ClientStream: + cs.aggregate.Database("admin") + case DatabaseStream: + cs.aggregate.Database(config.databaseName) + case CollectionStream: + cs.aggregate.Collection(config.collectionName).Database(config.databaseName) + default: + closeImplicitSession(cs.sess) + return nil, fmt.Errorf("must supply a valid StreamType in config, instead of %v", cs.streamType) + } + + // When starting a change stream, cache startAfter as the first resume token if it is set. If not, cache + // resumeAfter. If neither is set, do not cache a resume token. + resumeToken := cs.options.StartAfter + if resumeToken == nil { + resumeToken = cs.options.ResumeAfter + } + var marshaledToken bson.Raw + if resumeToken != nil { + if marshaledToken, cs.err = bson.Marshal(resumeToken); cs.err != nil { + closeImplicitSession(cs.sess) + return nil, cs.Err() + } + } + cs.resumeToken = marshaledToken + + if cs.err = cs.buildPipelineSlice(pipeline); cs.err != nil { + closeImplicitSession(cs.sess) + return nil, cs.Err() + } + var pipelineArr bsoncore.Document + pipelineArr, cs.err = cs.pipelineToBSON() + cs.aggregate.Pipeline(pipelineArr) + + if cs.err = cs.executeOperation(ctx, false); cs.err != nil { + closeImplicitSession(cs.sess) + return nil, cs.Err() + } + + return cs, cs.Err() +} + +func (cs *ChangeStream) createOperationDeployment(server driver.Server, connection driver.Connection) driver.Deployment { + return &changeStreamDeployment{ + topologyKind: cs.client.deployment.Kind(), + server: server, + conn: connection, + } +} + +func (cs *ChangeStream) executeOperation(ctx context.Context, resuming bool) error { + var server driver.Server + var conn driver.Connection + + if server, cs.err = cs.client.deployment.SelectServer(ctx, cs.selector); cs.err != nil { + return cs.Err() + } + if conn, cs.err = server.Connection(ctx); cs.err != nil { + return cs.Err() + } + defer conn.Close() + cs.wireVersion = conn.Description().WireVersion + + cs.aggregate.Deployment(cs.createOperationDeployment(server, conn)) + + if resuming { + cs.replaceOptions(cs.wireVersion) + + csOptDoc, err := cs.createPipelineOptionsDoc() + if err != nil { + return err + } + pipIdx, pipDoc := bsoncore.AppendDocumentStart(nil) + pipDoc = bsoncore.AppendDocumentElement(pipDoc, "$changeStream", csOptDoc) + if pipDoc, cs.err = bsoncore.AppendDocumentEnd(pipDoc, pipIdx); cs.err != nil { + return cs.Err() + } + cs.pipelineSlice[0] = pipDoc + + var plArr bsoncore.Document + if plArr, cs.err = cs.pipelineToBSON(); cs.err != nil { + return cs.Err() + } + cs.aggregate.Pipeline(plArr) + } + + // If no deadline is set on the passed-in context, cs.client.timeout is set, and context is not already + // a Timeout context, honor cs.client.timeout in new Timeout context for change stream operation execution + // and potential retry. + if _, deadlineSet := ctx.Deadline(); !deadlineSet && cs.client.timeout != nil && !internal.IsTimeoutContext(ctx) { + newCtx, cancelFunc := internal.MakeTimeoutContext(ctx, *cs.client.timeout) + // Redefine ctx to be the new timeout-derived context. 
+ ctx = newCtx + // Cancel the timeout-derived context at the end of executeOperation to avoid a context leak. + defer cancelFunc() + } + + // Execute the aggregate, retrying on retryable errors once (1) if retryable reads are enabled and + // infinitely (-1) if context is a Timeout context. + var retries int + if cs.client.retryReads { + retries = 1 + } + if internal.IsTimeoutContext(ctx) { + retries = -1 + } + + var err error +AggregateExecuteLoop: + for { + err = cs.aggregate.Execute(ctx) + // If no error or no retries remain, do not retry. + if err == nil || retries == 0 { + break AggregateExecuteLoop + } + + switch tt := err.(type) { + case driver.Error: + // If error is not retryable, do not retry. + if !tt.RetryableRead() { + break AggregateExecuteLoop + } + + // If error is retryable: subtract 1 from retries, redo server selection, checkout + // a connection, and restart loop. + retries-- + server, err = cs.client.deployment.SelectServer(ctx, cs.selector) + if err != nil { + break AggregateExecuteLoop + } + + conn.Close() + conn, err = server.Connection(ctx) + if err != nil { + break AggregateExecuteLoop + } + defer conn.Close() + + // Update the wire version with data from the new connection. + cs.wireVersion = conn.Description().WireVersion + + // Reset deployment. + cs.aggregate.Deployment(cs.createOperationDeployment(server, conn)) + default: + // Do not retry if error is not a driver error. + break AggregateExecuteLoop + } + } + if err != nil { + cs.err = replaceErrors(err) + return cs.err + } + + cr := cs.aggregate.ResultCursorResponse() + cr.Server = server + + cs.cursor, cs.err = driver.NewBatchCursor(cr, cs.sess, cs.client.clock, cs.cursorOptions) + if cs.err = replaceErrors(cs.err); cs.err != nil { + return cs.Err() + } + + cs.updatePbrtFromCommand() + if cs.options.StartAtOperationTime == nil && cs.options.ResumeAfter == nil && + cs.options.StartAfter == nil && cs.wireVersion.Max >= 7 && + cs.emptyBatch() && cs.resumeToken == nil { + cs.operationTime = cs.sess.OperationTime + } + + return cs.Err() +} + +// Updates the post batch resume token after a successful aggregate or getMore operation. 
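+// The token cached here is what ChangeStream.ResumeToken returns and what the
+// resume machinery replays after a resumable failure. For illustration, a typical
+// consumer loop looks like this (a minimal sketch, assuming a connected
+// *mongo.Collection named coll):
+//
+//	cs, err := coll.Watch(ctx, mongo.Pipeline{})
+//	if err != nil { /* handle error */ }
+//	defer cs.Close(ctx)
+//	for cs.Next(ctx) {
+//		var event bson.M
+//		if err := cs.Decode(&event); err != nil { /* handle error */ }
+//		token := cs.ResumeToken() // cached by this method and by storeResumeToken
+//		_ = token
+//	}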
+func (cs *ChangeStream) updatePbrtFromCommand() { + // Only cache the pbrt if an empty batch was returned and a pbrt was included + if pbrt := cs.cursor.PostBatchResumeToken(); cs.emptyBatch() && pbrt != nil { + cs.resumeToken = bson.Raw(pbrt) + } +} + +func (cs *ChangeStream) storeResumeToken() error { + // If cs.Current is the last document in the batch and a pbrt is included, cache the pbrt + // Otherwise, cache the _id of the document + var tokenDoc bson.Raw + if len(cs.batch) == 0 { + if pbrt := cs.cursor.PostBatchResumeToken(); pbrt != nil { + tokenDoc = bson.Raw(pbrt) + } + } + + if tokenDoc == nil { + var ok bool + tokenDoc, ok = cs.Current.Lookup("_id").DocumentOK() + if !ok { + _ = cs.Close(context.Background()) + return ErrMissingResumeToken + } + } + + cs.resumeToken = tokenDoc + return nil +} + +func (cs *ChangeStream) buildPipelineSlice(pipeline interface{}) error { + val := reflect.ValueOf(pipeline) + if !val.IsValid() || !(val.Kind() == reflect.Slice) { + cs.err = errors.New("can only marshal slices and arrays into aggregation pipelines, but got invalid") + return cs.err + } + + cs.pipelineSlice = make([]bsoncore.Document, 0, val.Len()+1) + + csIdx, csDoc := bsoncore.AppendDocumentStart(nil) + + csDocTemp, err := cs.createPipelineOptionsDoc() + if err != nil { + return err + } + csDoc = bsoncore.AppendDocumentElement(csDoc, "$changeStream", csDocTemp) + csDoc, cs.err = bsoncore.AppendDocumentEnd(csDoc, csIdx) + if cs.err != nil { + return cs.err + } + cs.pipelineSlice = append(cs.pipelineSlice, csDoc) + + for i := 0; i < val.Len(); i++ { + var elem []byte + elem, cs.err = marshal(val.Index(i).Interface(), cs.bsonOpts, cs.registry) + if cs.err != nil { + return cs.err + } + + cs.pipelineSlice = append(cs.pipelineSlice, elem) + } + + return cs.err +} + +func (cs *ChangeStream) createPipelineOptionsDoc() (bsoncore.Document, error) { + plDocIdx, plDoc := bsoncore.AppendDocumentStart(nil) + + if cs.streamType == ClientStream { + plDoc = bsoncore.AppendBooleanElement(plDoc, "allChangesForCluster", true) + } + + if cs.options.FullDocument != nil && *cs.options.FullDocument != options.Default { + plDoc = bsoncore.AppendStringElement(plDoc, "fullDocument", string(*cs.options.FullDocument)) + } + + if cs.options.FullDocumentBeforeChange != nil { + plDoc = bsoncore.AppendStringElement(plDoc, "fullDocumentBeforeChange", string(*cs.options.FullDocumentBeforeChange)) + } + + if cs.options.ResumeAfter != nil { + var raDoc bsoncore.Document + raDoc, cs.err = marshal(cs.options.ResumeAfter, cs.bsonOpts, cs.registry) + if cs.err != nil { + return nil, cs.err + } + + plDoc = bsoncore.AppendDocumentElement(plDoc, "resumeAfter", raDoc) + } + + if cs.options.ShowExpandedEvents != nil { + plDoc = bsoncore.AppendBooleanElement(plDoc, "showExpandedEvents", *cs.options.ShowExpandedEvents) + } + + if cs.options.StartAfter != nil { + var saDoc bsoncore.Document + saDoc, cs.err = marshal(cs.options.StartAfter, cs.bsonOpts, cs.registry) + if cs.err != nil { + return nil, cs.err + } + + plDoc = bsoncore.AppendDocumentElement(plDoc, "startAfter", saDoc) + } + + if cs.options.StartAtOperationTime != nil { + plDoc = bsoncore.AppendTimestampElement(plDoc, "startAtOperationTime", cs.options.StartAtOperationTime.T, cs.options.StartAtOperationTime.I) + } + + // Append custom pipeline options. 
+ for optionName, optionValue := range cs.pipelineOptions { + plDoc = bsoncore.AppendValueElement(plDoc, optionName, optionValue) + } + + if plDoc, cs.err = bsoncore.AppendDocumentEnd(plDoc, plDocIdx); cs.err != nil { + return nil, cs.err + } + + return plDoc, nil +} + +func (cs *ChangeStream) pipelineToBSON() (bsoncore.Document, error) { + pipelineDocIdx, pipelineArr := bsoncore.AppendArrayStart(nil) + for i, doc := range cs.pipelineSlice { + pipelineArr = bsoncore.AppendDocumentElement(pipelineArr, strconv.Itoa(i), doc) + } + if pipelineArr, cs.err = bsoncore.AppendArrayEnd(pipelineArr, pipelineDocIdx); cs.err != nil { + return nil, cs.err + } + return pipelineArr, cs.err +} + +func (cs *ChangeStream) replaceOptions(wireVersion *description.VersionRange) { + // Cached resume token: use the resume token as the resumeAfter option and set no other resume options + if cs.resumeToken != nil { + cs.options.SetResumeAfter(cs.resumeToken) + cs.options.SetStartAfter(nil) + cs.options.SetStartAtOperationTime(nil) + return + } + + // No cached resume token but cached operation time: use the operation time as the startAtOperationTime option and + // set no other resume options + if (cs.sess.OperationTime != nil || cs.options.StartAtOperationTime != nil) && wireVersion.Max >= 7 { + opTime := cs.options.StartAtOperationTime + if cs.operationTime != nil { + opTime = cs.sess.OperationTime + } + + cs.options.SetStartAtOperationTime(opTime) + cs.options.SetResumeAfter(nil) + cs.options.SetStartAfter(nil) + return + } + + // No cached resume token or operation time: set none of the resume options + cs.options.SetResumeAfter(nil) + cs.options.SetStartAfter(nil) + cs.options.SetStartAtOperationTime(nil) +} + +// ID returns the ID for this change stream, or 0 if the cursor has been closed or exhausted. +func (cs *ChangeStream) ID() int64 { + if cs.cursor == nil { + return 0 + } + return cs.cursor.ID() +} + +// Decode will unmarshal the current event document into val and return any errors from the unmarshalling process +// without any modification. If val is nil or is a typed nil, an error will be returned. +func (cs *ChangeStream) Decode(val interface{}) error { + if cs.cursor == nil { + return ErrNilCursor + } + + dec, err := getDecoder(cs.Current, cs.bsonOpts, cs.registry) + if err != nil { + return fmt.Errorf("error configuring BSON decoder: %w", err) + } + return dec.Decode(val) +} + +// Err returns the last error seen by the change stream, or nil if no errors has occurred. +func (cs *ChangeStream) Err() error { + if cs.err != nil { + return replaceErrors(cs.err) + } + if cs.cursor == nil { + return nil + } + + return replaceErrors(cs.cursor.Err()) +} + +// Close closes this change stream and the underlying cursor. Next and TryNext must not be called after Close has been +// called. Close is idempotent. After the first call, any subsequent calls will not change the state. +func (cs *ChangeStream) Close(ctx context.Context) error { + if ctx == nil { + ctx = context.Background() + } + + defer closeImplicitSession(cs.sess) + + if cs.cursor == nil { + return nil // cursor is already closed + } + + cs.err = replaceErrors(cs.cursor.Close(ctx)) + cs.cursor = nil + return cs.Err() +} + +// ResumeToken returns the last cached resume token for this change stream, or nil if a resume token has not been +// stored. +func (cs *ChangeStream) ResumeToken() bson.Raw { + return cs.resumeToken +} + +// Next gets the next event for this change stream. 
It returns true if there were no errors and the next event document +// is available. +// +// Next blocks until an event is available, an error occurs, or ctx expires. If ctx expires, the error +// will be set to ctx.Err(). In an error case, Next will return false. +// +// If Next returns false, subsequent calls will also return false. +func (cs *ChangeStream) Next(ctx context.Context) bool { + return cs.next(ctx, false) +} + +// TryNext attempts to get the next event for this change stream. It returns true if there were no errors and the next +// event document is available. +// +// TryNext returns false if the change stream is closed by the server, an error occurs when getting changes from the +// server, the next change is not yet available, or ctx expires. If ctx expires, the error will be set to ctx.Err(). +// +// If TryNext returns false and an error occurred or the change stream was closed +// (i.e. cs.Err() != nil || cs.ID() == 0), subsequent attempts will also return false. Otherwise, it is safe to call +// TryNext again until a change is available. +// +// This method requires driver version >= 1.2.0. +func (cs *ChangeStream) TryNext(ctx context.Context) bool { + return cs.next(ctx, true) +} + +func (cs *ChangeStream) next(ctx context.Context, nonBlocking bool) bool { + // return false right away if the change stream has already errored or if cursor is closed. + if cs.err != nil { + return false + } + + if ctx == nil { + ctx = context.Background() + } + + if len(cs.batch) == 0 { + cs.loopNext(ctx, nonBlocking) + if cs.err != nil { + cs.err = replaceErrors(cs.err) + return false + } + if len(cs.batch) == 0 { + return false + } + } + + // successfully got non-empty batch + cs.Current = bson.Raw(cs.batch[0]) + cs.batch = cs.batch[1:] + if cs.err = cs.storeResumeToken(); cs.err != nil { + return false + } + return true +} + +func (cs *ChangeStream) loopNext(ctx context.Context, nonBlocking bool) { + for { + if cs.cursor == nil { + return + } + + if cs.cursor.Next(ctx) { + // non-empty batch returned + cs.batch, cs.err = cs.cursor.Batch().Documents() + return + } + + cs.err = replaceErrors(cs.cursor.Err()) + if cs.err == nil { + // Check if cursor is alive + if cs.ID() == 0 { + return + } + + // If a getMore was done but the batch was empty, the batch cursor will return false with no error. + // Update the tracked resume token to catch the post batch resume token from the server response. + cs.updatePbrtFromCommand() + if nonBlocking { + // stop after a successful getMore, even though the batch was empty + return + } + continue // loop getMore until a non-empty batch is returned or an error occurs + } + + if !cs.isResumableError() { + return + } + + // ignore error from cursor close because if the cursor is deleted or errors we tried to close it and will remake and try to get next batch + _ = cs.cursor.Close(ctx) + if cs.err = cs.executeOperation(ctx, true); cs.err != nil { + return + } + } +} + +func (cs *ChangeStream) isResumableError() bool { + commandErr, ok := cs.err.(CommandError) + if !ok || commandErr.HasErrorLabel(networkErrorLabel) { + // All non-server errors or network errors are resumable. + return true + } + + if commandErr.Code == errorCursorNotFound { + return true + } + + // For wire versions 9 and above, a server error is resumable if it has the ResumableChangeStreamError label. 
+ if cs.wireVersion != nil && cs.wireVersion.Includes(minResumableLabelWireVersion) { + return commandErr.HasErrorLabel(resumableErrorLabel) + } + + // For wire versions below 9, a server error is resumable if its code is on the allowlist. + _, resumable := resumableChangeStreamErrors[commandErr.Code] + return resumable +} + +// Returns true if the underlying cursor's batch is empty +func (cs *ChangeStream) emptyBatch() bool { + return cs.cursor.Batch().Empty() +} + +// StreamType represents the cluster type against which a ChangeStream was created. +type StreamType uint8 + +// These constants represent valid change stream types. A change stream can be initialized over a collection, all +// collections in a database, or over a cluster. +const ( + CollectionStream StreamType = iota + DatabaseStream + ClientStream +) diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/options/loggeroptions.go b/vendor/go.mongodb.org/mongo-driver/mongo/options/loggeroptions.go new file mode 100644 index 0000000000..4a33e449a5 --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/mongo/options/loggeroptions.go @@ -0,0 +1,103 @@ +// Copyright (C) MongoDB, Inc. 2023-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package options + +import ( + "go.mongodb.org/mongo-driver/internal/logger" +) + +// LogLevel is an enumeration representing the supported log severity levels. +type LogLevel int + +const ( + // LogLevelInfo enables logging of informational messages. These logs + // are high-level information about normal driver behavior. + LogLevelInfo LogLevel = LogLevel(logger.LevelInfo) + + // LogLevelDebug enables logging of debug messages. These logs can be + // voluminous and are intended for detailed information that may be + // helpful when debugging an application. + LogLevelDebug LogLevel = LogLevel(logger.LevelDebug) +) + +// LogComponent is an enumeration representing the "components" which can be +// logged against. A LogLevel can be configured on a per-component basis. +type LogComponent int + +const ( + // LogComponentAll enables logging for all components. + LogComponentAll LogComponent = LogComponent(logger.ComponentAll) + + // LogComponentCommand enables command monitor logging. + LogComponentCommand LogComponent = LogComponent(logger.ComponentCommand) + + // LogComponentTopology enables topology logging. + LogComponentTopology LogComponent = LogComponent(logger.ComponentTopology) + + // LogComponentServerSelection enables server selection logging. + LogComponentServerSelection LogComponent = LogComponent(logger.ComponentServerSelection) + + // LogComponentConnection enables connection services logging. + LogComponentConnection LogComponent = LogComponent(logger.ComponentConnection) +) + +// LogSink is an interface that can be implemented to provide a custom sink for +// the driver's logs. +type LogSink interface { + // Info logs a non-error message with the given key/value pairs. This + // method will only be called if the provided level has been defined + // for a component in the LoggerOptions. + Info(level int, message string, keysAndValues ...interface{}) + + // Error logs an error message with the given key/value pairs + Error(err error, message string, keysAndValues ...interface{}) +} + +// LoggerOptions represent options used to configure Logging in the Go Driver. 
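+//
+// For illustration, a minimal wiring sketch (assuming the usual pattern of
+// passing these options through options.Client and its SetLoggerOptions setter):
+//
+//	loggerOpts := options.Logger().
+//		SetComponentLevel(options.LogComponentCommand, options.LogLevelDebug).
+//		SetMaxDocumentLength(512)
+//	clientOpts := options.Client().
+//		ApplyURI("mongodb://localhost:27017").
+//		SetLoggerOptions(loggerOpts)
+//	client, err := mongo.Connect(context.TODO(), clientOpts)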
+type LoggerOptions struct { + // ComponentLevels is a map of LogComponent to LogLevel. The LogLevel + // for a given LogComponent will be used to determine if a log message + // should be logged. + ComponentLevels map[LogComponent]LogLevel + + // Sink is the LogSink that will be used to log messages. If this is + // nil, the driver will use the standard logging library. + Sink LogSink + + // MaxDocumentLength is the maximum length of a document to be logged. + // If the underlying document is larger than this value, it will be + // truncated and appended with an ellipses "...". + MaxDocumentLength uint +} + +// Logger creates a new LoggerOptions instance. +func Logger() *LoggerOptions { + return &LoggerOptions{ + ComponentLevels: map[LogComponent]LogLevel{}, + } +} + +// SetComponentLevel sets the LogLevel value for a LogComponent. +func (opts *LoggerOptions) SetComponentLevel(component LogComponent, level LogLevel) *LoggerOptions { + opts.ComponentLevels[component] = level + + return opts +} + +// SetMaxDocumentLength sets the maximum length of a document to be logged. +func (opts *LoggerOptions) SetMaxDocumentLength(maxDocumentLength uint) *LoggerOptions { + opts.MaxDocumentLength = maxDocumentLength + + return opts +} + +// SetSink sets the LogSink to use for logging. +func (opts *LoggerOptions) SetSink(sink LogSink) *LoggerOptions { + opts.Sink = sink + + return opts +} diff --git a/vendor/go.mongodb.org/mongo-driver/mongo/readconcern/readconcern.go b/vendor/go.mongodb.org/mongo-driver/mongo/readconcern/readconcern.go new file mode 100644 index 0000000000..987f416055 --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/mongo/readconcern/readconcern.go @@ -0,0 +1,123 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +// Package readconcern defines read concerns for MongoDB operations. +// +// For more information about MongoDB read concerns, see +// https://www.mongodb.com/docs/manual/reference/read-concern/ +package readconcern // import "go.mongodb.org/mongo-driver/mongo/readconcern" + +import ( + "go.mongodb.org/mongo-driver/bson/bsontype" + "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" +) + +// A ReadConcern defines a MongoDB read concern, which allows you to control the consistency and +// isolation properties of the data read from replica sets and replica set shards. +// +// For more information about MongoDB read concerns, see +// https://www.mongodb.com/docs/manual/reference/read-concern/ +type ReadConcern struct { + Level string +} + +// Option is an option to provide when creating a ReadConcern. +// +// Deprecated: Use the ReadConcern literal declaration instead. For example: +// +// &readconcern.ReadConcern{Level: "local"} +type Option func(concern *ReadConcern) + +// Level creates an option that sets the level of a ReadConcern. +// +// Deprecated: Use the ReadConcern literal declaration instead. For example: +// +// &readconcern.ReadConcern{Level: "local"} +func Level(level string) Option { + return func(concern *ReadConcern) { + concern.Level = level + } +} + +// Local returns a ReadConcern that requests data from the instance with no guarantee that the data +// has been written to a majority of the replica set members (i.e. may be rolled back). 
+// +// For more information about read concern "local", see +// https://www.mongodb.com/docs/manual/reference/read-concern-local/ +func Local() *ReadConcern { + return New(Level("local")) +} + +// Majority returns a ReadConcern that requests data that has been acknowledged by a majority of the +// replica set members (i.e. the documents read are durable and guaranteed not to roll back). +// +// For more information about read concern "majority", see +// https://www.mongodb.com/docs/manual/reference/read-concern-majority/ +func Majority() *ReadConcern { + return New(Level("majority")) +} + +// Linearizable returns a ReadConcern that requests data that reflects all successful +// majority-acknowledged writes that completed prior to the start of the read operation. +// +// For more information about read concern "linearizable", see +// https://www.mongodb.com/docs/manual/reference/read-concern-linearizable/ +func Linearizable() *ReadConcern { + return New(Level("linearizable")) +} + +// Available returns a ReadConcern that requests data from an instance with no guarantee that the +// data has been written to a majority of the replica set members (i.e. may be rolled back). +// +// For more information about read concern "available", see +// https://www.mongodb.com/docs/manual/reference/read-concern-available/ +func Available() *ReadConcern { + return New(Level("available")) +} + +// Snapshot returns a ReadConcern that requests majority-committed data as it appears across shards +// from a specific single point in time in the recent past. +// +// For more information about read concern "snapshot", see +// https://www.mongodb.com/docs/manual/reference/read-concern-snapshot/ +func Snapshot() *ReadConcern { + return New(Level("snapshot")) +} + +// New constructs a new read concern from the given string. +// +// Deprecated: Use the ReadConcern literal declaration instead. For example: +// +// &readconcern.ReadConcern{Level: "local"} +func New(options ...Option) *ReadConcern { + concern := &ReadConcern{} + + for _, option := range options { + option(concern) + } + + return concern +} + +// MarshalBSONValue implements the bson.ValueMarshaler interface. +// +// Deprecated: Marshaling a ReadConcern to BSON will not be supported in Go Driver 2.0. +func (rc *ReadConcern) MarshalBSONValue() (bsontype.Type, []byte, error) { + var elems []byte + + if len(rc.Level) > 0 { + elems = bsoncore.AppendStringElement(elems, "level", rc.Level) + } + + return bsontype.EmbeddedDocument, bsoncore.BuildDocument(nil, elems), nil +} + +// GetLevel returns the read concern level. +// +// Deprecated: Use the ReadConcern.Level field instead. +func (rc *ReadConcern) GetLevel() string { + return rc.Level +} diff --git a/vendor/go.mongodb.org/mongo-driver/version/version.go b/vendor/go.mongodb.org/mongo-driver/version/version.go new file mode 100644 index 0000000000..354548fe7a --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/version/version.go @@ -0,0 +1,11 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +// Package version defines the Go Driver version. +package version // import "go.mongodb.org/mongo-driver/version" + +// Driver is the current version of the driver. 
+var Driver = "v1.12.0"
diff --git a/vendor/go.mongodb.org/mongo-driver/x/bsonx/bsoncore/bsoncore.go b/vendor/go.mongodb.org/mongo-driver/x/bsonx/bsoncore/bsoncore.go
new file mode 100644
index 0000000000..94d479428f
--- /dev/null
+++ b/vendor/go.mongodb.org/mongo-driver/x/bsonx/bsoncore/bsoncore.go
@@ -0,0 +1,843 @@
+// Copyright (C) MongoDB, Inc. 2017-present.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License. You may obtain
+// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
+
+package bsoncore // import "go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
+
+import (
+	"bytes"
+	"fmt"
+	"math"
+	"strconv"
+	"strings"
+	"time"
+
+	"go.mongodb.org/mongo-driver/bson/bsontype"
+	"go.mongodb.org/mongo-driver/bson/primitive"
+)
+
+const (
+	// EmptyDocumentLength is the length of a document that has been started/ended but has no elements.
+	EmptyDocumentLength = 5
+	// nullTerminator is a string version of the 0 byte that is appended at the end of cstrings.
+	nullTerminator       = string(byte(0))
+	invalidKeyPanicMsg   = "BSON element keys cannot contain null bytes"
+	invalidRegexPanicMsg = "BSON regex values cannot contain null bytes"
+)
+
+// AppendType will append t to dst and return the extended buffer.
+func AppendType(dst []byte, t bsontype.Type) []byte { return append(dst, byte(t)) }
+
+// AppendKey will append key to dst and return the extended buffer.
+func AppendKey(dst []byte, key string) []byte { return append(dst, key+nullTerminator...) }
+
+// AppendHeader will append Type t and key to dst and return the extended
+// buffer.
+func AppendHeader(dst []byte, t bsontype.Type, key string) []byte {
+	if !isValidCString(key) {
+		panic(invalidKeyPanicMsg)
+	}
+
+	dst = AppendType(dst, t)
+	dst = append(dst, key...)
+	return append(dst, 0x00)
+	// return append(AppendType(dst, t), key+string(0x00)...)
+}
+
+// TODO(skriptble): All of the Read* functions should return src resliced to start just after what was read.
+
+// ReadType will return the first byte of the provided []byte as a type. If
+// there is no available byte, false is returned.
+func ReadType(src []byte) (bsontype.Type, []byte, bool) {
+	if len(src) < 1 {
+		return 0, src, false
+	}
+	return bsontype.Type(src[0]), src[1:], true
+}
+
+// ReadKey will read a key from src. The 0x00 byte will not be present
+// in the returned string. If there are not enough bytes available, false is
+// returned.
+func ReadKey(src []byte) (string, []byte, bool) { return readcstring(src) }
+
+// ReadKeyBytes will read a key from src as bytes. The 0x00 byte will
+// not be present in the returned string. If there are not enough bytes
+// available, false is returned.
+func ReadKeyBytes(src []byte) ([]byte, []byte, bool) { return readcstringbytes(src) }
+
+// ReadHeader will read a type byte and a key from src. If either of these
+// values cannot be read, false is returned.
+func ReadHeader(src []byte) (t bsontype.Type, key string, rem []byte, ok bool) {
+	t, rem, ok = ReadType(src)
+	if !ok {
+		return 0, "", src, false
+	}
+	key, rem, ok = ReadKey(rem)
+	if !ok {
+		return 0, "", src, false
+	}
+
+	return t, key, rem, true
+}
+
+// ReadHeaderBytes will read a type and a key from src and return the remainder
+// of the bytes as rem. If either the type or key cannot be read, ok will be false.
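+//
+// For illustration, a minimal round trip through the header helpers above (a
+// sketch; AppendHeader panics if the key contains a null byte):
+//
+//	buf := AppendHeader(nil, bsontype.String, "greeting")
+//	t, key, rem, ok := ReadHeader(buf)
+//	// t == bsontype.String, key == "greeting", len(rem) == 0, ok == true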
+func ReadHeaderBytes(src []byte) (header []byte, rem []byte, ok bool) { + if len(src) < 1 { + return nil, src, false + } + idx := bytes.IndexByte(src[1:], 0x00) + if idx == -1 { + return nil, src, false + } + return src[:idx], src[idx+1:], true +} + +// ReadElement reads the next full element from src. It returns the element, the remaining bytes in +// the slice, and a boolean indicating if the read was successful. +func ReadElement(src []byte) (Element, []byte, bool) { + if len(src) < 1 { + return nil, src, false + } + t := bsontype.Type(src[0]) + idx := bytes.IndexByte(src[1:], 0x00) + if idx == -1 { + return nil, src, false + } + length, ok := valueLength(src[idx+2:], t) // We add 2 here because we called IndexByte with src[1:] + if !ok { + return nil, src, false + } + elemLength := 1 + idx + 1 + int(length) + if elemLength > len(src) { + return nil, src, false + } + if elemLength < 0 { + return nil, src, false + } + return src[:elemLength], src[elemLength:], true +} + +// AppendValueElement appends value to dst as an element using key as the element's key. +func AppendValueElement(dst []byte, key string, value Value) []byte { + dst = AppendHeader(dst, value.Type, key) + dst = append(dst, value.Data...) + return dst +} + +// ReadValue reads the next value as the provided types and returns a Value, the remaining bytes, +// and a boolean indicating if the read was successful. +func ReadValue(src []byte, t bsontype.Type) (Value, []byte, bool) { + data, rem, ok := readValue(src, t) + if !ok { + return Value{}, src, false + } + return Value{Type: t, Data: data}, rem, true +} + +// AppendDouble will append f to dst and return the extended buffer. +func AppendDouble(dst []byte, f float64) []byte { + return appendu64(dst, math.Float64bits(f)) +} + +// AppendDoubleElement will append a BSON double element using key and f to dst +// and return the extended buffer. +func AppendDoubleElement(dst []byte, key string, f float64) []byte { + return AppendDouble(AppendHeader(dst, bsontype.Double, key), f) +} + +// ReadDouble will read a float64 from src. If there are not enough bytes it +// will return false. +func ReadDouble(src []byte) (float64, []byte, bool) { + bits, src, ok := readu64(src) + if !ok { + return 0, src, false + } + return math.Float64frombits(bits), src, true +} + +// AppendString will append s to dst and return the extended buffer. +func AppendString(dst []byte, s string) []byte { + return appendstring(dst, s) +} + +// AppendStringElement will append a BSON string element using key and val to dst +// and return the extended buffer. +func AppendStringElement(dst []byte, key, val string) []byte { + return AppendString(AppendHeader(dst, bsontype.String, key), val) +} + +// ReadString will read a string from src. If there are not enough bytes it +// will return false. +func ReadString(src []byte) (string, []byte, bool) { + return readstring(src) +} + +// AppendDocumentStart reserves a document's length and returns the index where the length begins. +// This index can later be used to write the length of the document. +func AppendDocumentStart(dst []byte) (index int32, b []byte) { + // TODO(skriptble): We really need AppendDocumentStart and AppendDocumentEnd. AppendDocumentStart would handle calling + // TODO ReserveLength and providing the index of the start of the document. AppendDocumentEnd would handle taking that + // TODO start index, adding the null byte, calculating the length, and filling in the length at the start of the + // TODO document. 
+	return ReserveLength(dst)
+}
+
+// AppendDocumentStartInline functions the same as AppendDocumentStart but takes a pointer to the
+// index int32 which allows this function to be used inline.
+func AppendDocumentStartInline(dst []byte, index *int32) []byte {
+	idx, doc := AppendDocumentStart(dst)
+	*index = idx
+	return doc
+}
+
+// AppendDocumentElementStart writes a document element header and then reserves the length bytes.
+func AppendDocumentElementStart(dst []byte, key string) (index int32, b []byte) {
+	return AppendDocumentStart(AppendHeader(dst, bsontype.EmbeddedDocument, key))
+}
+
+// AppendDocumentEnd writes the null byte for a document and updates the length of the document.
+// The index should be the beginning of the document's length bytes.
+func AppendDocumentEnd(dst []byte, index int32) ([]byte, error) {
+	if int(index) > len(dst)-4 {
+		return dst, fmt.Errorf("not enough bytes available after index to write length")
+	}
+	dst = append(dst, 0x00)
+	dst = UpdateLength(dst, index, int32(len(dst[index:])))
+	return dst, nil
+}
+
+// AppendDocument will append doc to dst and return the extended buffer.
+func AppendDocument(dst []byte, doc []byte) []byte { return append(dst, doc...) }
+
+// AppendDocumentElement will append a BSON embedded document element using key
+// and doc to dst and return the extended buffer.
+func AppendDocumentElement(dst []byte, key string, doc []byte) []byte {
+	return AppendDocument(AppendHeader(dst, bsontype.EmbeddedDocument, key), doc)
+}
+
+// BuildDocument will create a document with the given slice of elements and will append
+// it to dst and return the extended buffer.
+func BuildDocument(dst []byte, elems ...[]byte) []byte {
+	idx, dst := ReserveLength(dst)
+	for _, elem := range elems {
+		dst = append(dst, elem...)
+	}
+	dst = append(dst, 0x00)
+	dst = UpdateLength(dst, idx, int32(len(dst[idx:])))
+	return dst
+}
+
+// BuildDocumentValue creates an Embedded Document value from the given elements.
+func BuildDocumentValue(elems ...[]byte) Value {
+	return Value{Type: bsontype.EmbeddedDocument, Data: BuildDocument(nil, elems...)}
+}
+
+// BuildDocumentElement will append a BSON embedded document element using key and the provided
+// elements and return the extended buffer.
+func BuildDocumentElement(dst []byte, key string, elems ...[]byte) []byte {
+	return BuildDocument(AppendHeader(dst, bsontype.EmbeddedDocument, key), elems...)
+}
+
+// BuildDocumentFromElements is an alias for the BuildDocument function.
+var BuildDocumentFromElements = BuildDocument
+
+// ReadDocument will read a document from src. If there are not enough bytes it
+// will return false.
+func ReadDocument(src []byte) (doc Document, rem []byte, ok bool) { return readLengthBytes(src) }
+
+// AppendArrayStart appends the length bytes to an array and then returns the index of the start
+// of those length bytes.
+func AppendArrayStart(dst []byte) (index int32, b []byte) { return ReserveLength(dst) }
+
+// AppendArrayElementStart appends an array element header and then the length bytes for an array,
+// returning the index where the length starts.
+func AppendArrayElementStart(dst []byte, key string) (index int32, b []byte) {
+	return AppendArrayStart(AppendHeader(dst, bsontype.Array, key))
+}
+
+// AppendArrayEnd appends the null byte to an array and calculates the length, inserting that
+// calculated length starting at index.
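+//
+// For illustration, the start/end pairing builds an array like this (a sketch):
+//
+//	idx, arr := AppendArrayStart(nil)
+//	arr = AppendStringElement(arr, "0", "a")
+//	arr = AppendStringElement(arr, "1", "b")
+//	arr, err := AppendArrayEnd(arr, idx)
+//	// arr now holds the BSON array ["a", "b"]; err is non-nil only for an invalid index.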
+func AppendArrayEnd(dst []byte, index int32) ([]byte, error) { return AppendDocumentEnd(dst, index) } + +// AppendArray will append arr to dst and return the extended buffer. +func AppendArray(dst []byte, arr []byte) []byte { return append(dst, arr...) } + +// AppendArrayElement will append a BSON array element using key and arr to dst +// and return the extended buffer. +func AppendArrayElement(dst []byte, key string, arr []byte) []byte { + return AppendArray(AppendHeader(dst, bsontype.Array, key), arr) +} + +// BuildArray will append a BSON array to dst built from values. +func BuildArray(dst []byte, values ...Value) []byte { + idx, dst := ReserveLength(dst) + for pos, val := range values { + dst = AppendValueElement(dst, strconv.Itoa(pos), val) + } + dst = append(dst, 0x00) + dst = UpdateLength(dst, idx, int32(len(dst[idx:]))) + return dst +} + +// BuildArrayElement will create an array element using the provided values. +func BuildArrayElement(dst []byte, key string, values ...Value) []byte { + return BuildArray(AppendHeader(dst, bsontype.Array, key), values...) +} + +// ReadArray will read an array from src. If there are not enough bytes it +// will return false. +func ReadArray(src []byte) (arr Array, rem []byte, ok bool) { return readLengthBytes(src) } + +// AppendBinary will append subtype and b to dst and return the extended buffer. +func AppendBinary(dst []byte, subtype byte, b []byte) []byte { + if subtype == 0x02 { + return appendBinarySubtype2(dst, subtype, b) + } + dst = append(appendLength(dst, int32(len(b))), subtype) + return append(dst, b...) +} + +// AppendBinaryElement will append a BSON binary element using key, subtype, and +// b to dst and return the extended buffer. +func AppendBinaryElement(dst []byte, key string, subtype byte, b []byte) []byte { + return AppendBinary(AppendHeader(dst, bsontype.Binary, key), subtype, b) +} + +// ReadBinary will read a subtype and bin from src. If there are not enough bytes it +// will return false. +func ReadBinary(src []byte) (subtype byte, bin []byte, rem []byte, ok bool) { + length, rem, ok := ReadLength(src) + if !ok { + return 0x00, nil, src, false + } + if len(rem) < 1 { // subtype + return 0x00, nil, src, false + } + subtype, rem = rem[0], rem[1:] + + if len(rem) < int(length) { + return 0x00, nil, src, false + } + + if subtype == 0x02 { + length, rem, ok = ReadLength(rem) + if !ok || len(rem) < int(length) { + return 0x00, nil, src, false + } + } + + return subtype, rem[:length], rem[length:], true +} + +// AppendUndefinedElement will append a BSON undefined element using key to dst +// and return the extended buffer. +func AppendUndefinedElement(dst []byte, key string) []byte { + return AppendHeader(dst, bsontype.Undefined, key) +} + +// AppendObjectID will append oid to dst and return the extended buffer. +func AppendObjectID(dst []byte, oid primitive.ObjectID) []byte { return append(dst, oid[:]...) } + +// AppendObjectIDElement will append a BSON ObjectID element using key and oid to dst +// and return the extended buffer. +func AppendObjectIDElement(dst []byte, key string, oid primitive.ObjectID) []byte { + return AppendObjectID(AppendHeader(dst, bsontype.ObjectID, key), oid) +} + +// ReadObjectID will read an ObjectID from src. If there are not enough bytes it +// will return false. 
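+//
+// For illustration, a round-trip sketch:
+//
+//	oid := primitive.NewObjectID()
+//	buf := AppendObjectID(nil, oid)
+//	got, rem, ok := ReadObjectID(buf)
+//	// got == oid, len(rem) == 0, ok == true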
+func ReadObjectID(src []byte) (primitive.ObjectID, []byte, bool) { + if len(src) < 12 { + return primitive.ObjectID{}, src, false + } + var oid primitive.ObjectID + copy(oid[:], src[0:12]) + return oid, src[12:], true +} + +// AppendBoolean will append b to dst and return the extended buffer. +func AppendBoolean(dst []byte, b bool) []byte { + if b { + return append(dst, 0x01) + } + return append(dst, 0x00) +} + +// AppendBooleanElement will append a BSON boolean element using key and b to dst +// and return the extended buffer. +func AppendBooleanElement(dst []byte, key string, b bool) []byte { + return AppendBoolean(AppendHeader(dst, bsontype.Boolean, key), b) +} + +// ReadBoolean will read a bool from src. If there are not enough bytes it +// will return false. +func ReadBoolean(src []byte) (bool, []byte, bool) { + if len(src) < 1 { + return false, src, false + } + + return src[0] == 0x01, src[1:], true +} + +// AppendDateTime will append dt to dst and return the extended buffer. +func AppendDateTime(dst []byte, dt int64) []byte { return appendi64(dst, dt) } + +// AppendDateTimeElement will append a BSON datetime element using key and dt to dst +// and return the extended buffer. +func AppendDateTimeElement(dst []byte, key string, dt int64) []byte { + return AppendDateTime(AppendHeader(dst, bsontype.DateTime, key), dt) +} + +// ReadDateTime will read an int64 datetime from src. If there are not enough bytes it +// will return false. +func ReadDateTime(src []byte) (int64, []byte, bool) { return readi64(src) } + +// AppendTime will append time as a BSON DateTime to dst and return the extended buffer. +func AppendTime(dst []byte, t time.Time) []byte { + return AppendDateTime(dst, t.Unix()*1000+int64(t.Nanosecond()/1e6)) +} + +// AppendTimeElement will append a BSON datetime element using key and dt to dst +// and return the extended buffer. +func AppendTimeElement(dst []byte, key string, t time.Time) []byte { + return AppendTime(AppendHeader(dst, bsontype.DateTime, key), t) +} + +// ReadTime will read an time.Time datetime from src. If there are not enough bytes it +// will return false. +func ReadTime(src []byte) (time.Time, []byte, bool) { + dt, rem, ok := readi64(src) + return time.Unix(dt/1e3, dt%1e3*1e6), rem, ok +} + +// AppendNullElement will append a BSON null element using key to dst +// and return the extended buffer. +func AppendNullElement(dst []byte, key string) []byte { return AppendHeader(dst, bsontype.Null, key) } + +// AppendRegex will append pattern and options to dst and return the extended buffer. +func AppendRegex(dst []byte, pattern, options string) []byte { + if !isValidCString(pattern) || !isValidCString(options) { + panic(invalidRegexPanicMsg) + } + + return append(dst, pattern+nullTerminator+options+nullTerminator...) +} + +// AppendRegexElement will append a BSON regex element using key, pattern, and +// options to dst and return the extended buffer. +func AppendRegexElement(dst []byte, key, pattern, options string) []byte { + return AppendRegex(AppendHeader(dst, bsontype.Regex, key), pattern, options) +} + +// ReadRegex will read a pattern and options from src. If there are not enough bytes it +// will return false. 
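+//
+// For illustration, a round-trip sketch (both cstrings must be null-free or
+// AppendRegex panics):
+//
+//	buf := AppendRegex(nil, "^foo", "i")
+//	pattern, opts, rem, ok := ReadRegex(buf)
+//	// pattern == "^foo", opts == "i", len(rem) == 0, ok == true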
+func ReadRegex(src []byte) (pattern, options string, rem []byte, ok bool) { + pattern, rem, ok = readcstring(src) + if !ok { + return "", "", src, false + } + options, rem, ok = readcstring(rem) + if !ok { + return "", "", src, false + } + return pattern, options, rem, true +} + +// AppendDBPointer will append ns and oid to dst and return the extended buffer. +func AppendDBPointer(dst []byte, ns string, oid primitive.ObjectID) []byte { + return append(appendstring(dst, ns), oid[:]...) +} + +// AppendDBPointerElement will append a BSON DBPointer element using key, ns, +// and oid to dst and return the extended buffer. +func AppendDBPointerElement(dst []byte, key, ns string, oid primitive.ObjectID) []byte { + return AppendDBPointer(AppendHeader(dst, bsontype.DBPointer, key), ns, oid) +} + +// ReadDBPointer will read a ns and oid from src. If there are not enough bytes it +// will return false. +func ReadDBPointer(src []byte) (ns string, oid primitive.ObjectID, rem []byte, ok bool) { + ns, rem, ok = readstring(src) + if !ok { + return "", primitive.ObjectID{}, src, false + } + oid, rem, ok = ReadObjectID(rem) + if !ok { + return "", primitive.ObjectID{}, src, false + } + return ns, oid, rem, true +} + +// AppendJavaScript will append js to dst and return the extended buffer. +func AppendJavaScript(dst []byte, js string) []byte { return appendstring(dst, js) } + +// AppendJavaScriptElement will append a BSON JavaScript element using key and +// js to dst and return the extended buffer. +func AppendJavaScriptElement(dst []byte, key, js string) []byte { + return AppendJavaScript(AppendHeader(dst, bsontype.JavaScript, key), js) +} + +// ReadJavaScript will read a js string from src. If there are not enough bytes it +// will return false. +func ReadJavaScript(src []byte) (js string, rem []byte, ok bool) { return readstring(src) } + +// AppendSymbol will append symbol to dst and return the extended buffer. +func AppendSymbol(dst []byte, symbol string) []byte { return appendstring(dst, symbol) } + +// AppendSymbolElement will append a BSON symbol element using key and symbol to dst +// and return the extended buffer. +func AppendSymbolElement(dst []byte, key, symbol string) []byte { + return AppendSymbol(AppendHeader(dst, bsontype.Symbol, key), symbol) +} + +// ReadSymbol will read a symbol string from src. If there are not enough bytes it +// will return false. +func ReadSymbol(src []byte) (symbol string, rem []byte, ok bool) { return readstring(src) } + +// AppendCodeWithScope will append code and scope to dst and return the extended buffer. +func AppendCodeWithScope(dst []byte, code string, scope []byte) []byte { + length := int32(4 + 4 + len(code) + 1 + len(scope)) // length of cws, length of code, code, 0x00, scope + dst = appendLength(dst, length) + + return append(appendstring(dst, code), scope...) +} + +// AppendCodeWithScopeElement will append a BSON code with scope element using +// key, code, and scope to dst +// and return the extended buffer. +func AppendCodeWithScopeElement(dst []byte, key, code string, scope []byte) []byte { + return AppendCodeWithScope(AppendHeader(dst, bsontype.CodeWithScope, key), code, scope) +} + +// ReadCodeWithScope will read code and scope from src. If there are not enough bytes it +// will return false. 
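+//
+// For illustration, a round-trip sketch:
+//
+//	scope := BuildDocument(nil, AppendInt32Element(nil, "x", 1))
+//	buf := AppendCodeWithScope(nil, "function() { return x; }", scope)
+//	code, scopeDoc, rem, ok := ReadCodeWithScope(buf)
+//	// code == "function() { return x; }", scopeDoc equals scope, ok == true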
+func ReadCodeWithScope(src []byte) (code string, scope []byte, rem []byte, ok bool) {
+	length, rem, ok := ReadLength(src)
+	if !ok || len(src) < int(length) {
+		return "", nil, src, false
+	}
+
+	code, rem, ok = readstring(rem)
+	if !ok {
+		return "", nil, src, false
+	}
+
+	scope, rem, ok = ReadDocument(rem)
+	if !ok {
+		return "", nil, src, false
+	}
+	return code, scope, rem, true
+}
+
+// AppendInt32 will append i32 to dst and return the extended buffer.
+func AppendInt32(dst []byte, i32 int32) []byte { return appendi32(dst, i32) }
+
+// AppendInt32Element will append a BSON int32 element using key and i32 to dst
+// and return the extended buffer.
+func AppendInt32Element(dst []byte, key string, i32 int32) []byte {
+	return AppendInt32(AppendHeader(dst, bsontype.Int32, key), i32)
+}
+
+// ReadInt32 will read an int32 from src. If there are not enough bytes it
+// will return false.
+func ReadInt32(src []byte) (int32, []byte, bool) { return readi32(src) }
+
+// AppendTimestamp will append t and i to dst and return the extended buffer.
+func AppendTimestamp(dst []byte, t, i uint32) []byte {
+	return appendu32(appendu32(dst, i), t) // i is the lower 4 bytes, t is the higher 4 bytes
+}
+
+// AppendTimestampElement will append a BSON timestamp element using key, t, and
+// i to dst and return the extended buffer.
+func AppendTimestampElement(dst []byte, key string, t, i uint32) []byte {
+	return AppendTimestamp(AppendHeader(dst, bsontype.Timestamp, key), t, i)
+}
+
+// ReadTimestamp will read t and i from src. If there are not enough bytes it
+// will return false.
+func ReadTimestamp(src []byte) (t, i uint32, rem []byte, ok bool) {
+	i, rem, ok = readu32(src)
+	if !ok {
+		return 0, 0, src, false
+	}
+	t, rem, ok = readu32(rem)
+	if !ok {
+		return 0, 0, src, false
+	}
+	return t, i, rem, true
+}
+
+// AppendInt64 will append i64 to dst and return the extended buffer.
+func AppendInt64(dst []byte, i64 int64) []byte { return appendi64(dst, i64) }
+
+// AppendInt64Element will append a BSON int64 element using key and i64 to dst
+// and return the extended buffer.
+func AppendInt64Element(dst []byte, key string, i64 int64) []byte {
+	return AppendInt64(AppendHeader(dst, bsontype.Int64, key), i64)
+}
+
+// ReadInt64 will read an int64 from src. If there are not enough bytes it
+// will return false.
+func ReadInt64(src []byte) (int64, []byte, bool) { return readi64(src) }
+
+// AppendDecimal128 will append d128 to dst and return the extended buffer.
+func AppendDecimal128(dst []byte, d128 primitive.Decimal128) []byte {
+	high, low := d128.GetBytes()
+	return appendu64(appendu64(dst, low), high)
+}
+
+// AppendDecimal128Element will append a BSON primitive.Decimal128 element using key and
+// d128 to dst and return the extended buffer.
+func AppendDecimal128Element(dst []byte, key string, d128 primitive.Decimal128) []byte {
+	return AppendDecimal128(AppendHeader(dst, bsontype.Decimal128, key), d128)
+}
+
+// ReadDecimal128 will read a primitive.Decimal128 from src. If there are not enough bytes it
+// will return false.
+func ReadDecimal128(src []byte) (primitive.Decimal128, []byte, bool) {
+	l, rem, ok := readu64(src)
+	if !ok {
+		return primitive.Decimal128{}, src, false
+	}
+
+	h, rem, ok := readu64(rem)
+	if !ok {
+		return primitive.Decimal128{}, src, false
+	}
+
+	return primitive.NewDecimal128(h, l), rem, true
+}
+
+// AppendMaxKeyElement will append a BSON max key element using key to dst
+// and return the extended buffer.
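+//
+// For illustration: min and max key elements carry no payload, only a header,
+// so building a document with both is just (a sketch):
+//
+//	elems := AppendMaxKeyElement(nil, "hi")
+//	elems = AppendMinKeyElement(elems, "lo")
+//	doc := BuildDocument(nil, elems)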
+func AppendMaxKeyElement(dst []byte, key string) []byte {
+	return AppendHeader(dst, bsontype.MaxKey, key)
+}
+
+// AppendMinKeyElement will append a BSON min key element using key to dst
+// and return the extended buffer.
+func AppendMinKeyElement(dst []byte, key string) []byte {
+	return AppendHeader(dst, bsontype.MinKey, key)
+}
+
+// EqualValue will return true if the two values are equal.
+func EqualValue(t1, t2 bsontype.Type, v1, v2 []byte) bool {
+	if t1 != t2 {
+		return false
+	}
+	v1, _, ok := readValue(v1, t1)
+	if !ok {
+		return false
+	}
+	v2, _, ok = readValue(v2, t2)
+	if !ok {
+		return false
+	}
+	return bytes.Equal(v1, v2)
+}
+
+// valueLength will determine the length of the next value contained in src as if it
+// is type t. The returned bool will be false if there are not enough bytes in src for
+// a value of type t.
+func valueLength(src []byte, t bsontype.Type) (int32, bool) {
+	var length int32
+	ok := true
+	switch t {
+	case bsontype.Array, bsontype.EmbeddedDocument, bsontype.CodeWithScope:
+		length, _, ok = ReadLength(src)
+	case bsontype.Binary:
+		length, _, ok = ReadLength(src)
+		length += 4 + 1 // binary length + subtype byte
+	case bsontype.Boolean:
+		length = 1
+	case bsontype.DBPointer:
+		length, _, ok = ReadLength(src)
+		length += 4 + 12 // string length + ObjectID length
+	case bsontype.DateTime, bsontype.Double, bsontype.Int64, bsontype.Timestamp:
+		length = 8
+	case bsontype.Decimal128:
+		length = 16
+	case bsontype.Int32:
+		length = 4
+	case bsontype.JavaScript, bsontype.String, bsontype.Symbol:
+		length, _, ok = ReadLength(src)
+		length += 4
+	case bsontype.MaxKey, bsontype.MinKey, bsontype.Null, bsontype.Undefined:
+		length = 0
+	case bsontype.ObjectID:
+		length = 12
+	case bsontype.Regex:
+		regex := bytes.IndexByte(src, 0x00)
+		if regex < 0 {
+			ok = false
+			break
+		}
+		pattern := bytes.IndexByte(src[regex+1:], 0x00)
+		if pattern < 0 {
+			ok = false
+			break
+		}
+		length = int32(int64(regex) + 1 + int64(pattern) + 1)
+	default:
+		ok = false
+	}
+
+	return length, ok
+}
+
+func readValue(src []byte, t bsontype.Type) ([]byte, []byte, bool) {
+	length, ok := valueLength(src, t)
+	if !ok || int(length) > len(src) {
+		return nil, src, false
+	}
+
+	return src[:length], src[length:], true
+}
+
+// ReserveLength reserves the space required for length and returns the index at which
+// to write the length and the []byte with reserved space.
+func ReserveLength(dst []byte) (int32, []byte) {
+	index := len(dst)
+	return int32(index), append(dst, 0x00, 0x00, 0x00, 0x00)
+}
+
+// UpdateLength updates the length at index with length and returns the []byte.
+func UpdateLength(dst []byte, index, length int32) []byte {
+	dst[index] = byte(length)
+	dst[index+1] = byte(length >> 8)
+	dst[index+2] = byte(length >> 16)
+	dst[index+3] = byte(length >> 24)
+	return dst
+}
+
+func appendLength(dst []byte, l int32) []byte { return appendi32(dst, l) }
+
+func appendi32(dst []byte, i32 int32) []byte {
+	return append(dst, byte(i32), byte(i32>>8), byte(i32>>16), byte(i32>>24))
+}
+
+// ReadLength reads an int32 length from src and returns the length and the remaining bytes. If
+// there aren't enough bytes to read a valid length, src is returned unmodified and the returned
+// bool will be false.
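+//
+// For illustration, ReserveLength/UpdateLength and ReadLength cooperate like
+// this (a sketch):
+//
+//	idx, buf := ReserveLength(nil) // reserve 4 bytes for the length
+//	buf = append(buf, 0xFF)        // payload
+//	buf = UpdateLength(buf, idx, int32(len(buf)))
+//	l, _, ok := ReadLength(buf)
+//	// l == 5 (4 length bytes + 1 payload byte), ok == true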
+func ReadLength(src []byte) (int32, []byte, bool) { + ln, src, ok := readi32(src) + if ln < 0 { + return ln, src, false + } + return ln, src, ok +} + +func readi32(src []byte) (int32, []byte, bool) { + if len(src) < 4 { + return 0, src, false + } + return (int32(src[0]) | int32(src[1])<<8 | int32(src[2])<<16 | int32(src[3])<<24), src[4:], true +} + +func appendi64(dst []byte, i64 int64) []byte { + return append(dst, + byte(i64), byte(i64>>8), byte(i64>>16), byte(i64>>24), + byte(i64>>32), byte(i64>>40), byte(i64>>48), byte(i64>>56), + ) +} + +func readi64(src []byte) (int64, []byte, bool) { + if len(src) < 8 { + return 0, src, false + } + i64 := (int64(src[0]) | int64(src[1])<<8 | int64(src[2])<<16 | int64(src[3])<<24 | + int64(src[4])<<32 | int64(src[5])<<40 | int64(src[6])<<48 | int64(src[7])<<56) + return i64, src[8:], true +} + +func appendu32(dst []byte, u32 uint32) []byte { + return append(dst, byte(u32), byte(u32>>8), byte(u32>>16), byte(u32>>24)) +} + +func readu32(src []byte) (uint32, []byte, bool) { + if len(src) < 4 { + return 0, src, false + } + + return (uint32(src[0]) | uint32(src[1])<<8 | uint32(src[2])<<16 | uint32(src[3])<<24), src[4:], true +} + +func appendu64(dst []byte, u64 uint64) []byte { + return append(dst, + byte(u64), byte(u64>>8), byte(u64>>16), byte(u64>>24), + byte(u64>>32), byte(u64>>40), byte(u64>>48), byte(u64>>56), + ) +} + +func readu64(src []byte) (uint64, []byte, bool) { + if len(src) < 8 { + return 0, src, false + } + u64 := (uint64(src[0]) | uint64(src[1])<<8 | uint64(src[2])<<16 | uint64(src[3])<<24 | + uint64(src[4])<<32 | uint64(src[5])<<40 | uint64(src[6])<<48 | uint64(src[7])<<56) + return u64, src[8:], true +} + +// keep in sync with readcstringbytes +func readcstring(src []byte) (string, []byte, bool) { + idx := bytes.IndexByte(src, 0x00) + if idx < 0 { + return "", src, false + } + return string(src[:idx]), src[idx+1:], true +} + +// keep in sync with readcstring +func readcstringbytes(src []byte) ([]byte, []byte, bool) { + idx := bytes.IndexByte(src, 0x00) + if idx < 0 { + return nil, src, false + } + return src[:idx], src[idx+1:], true +} + +func appendstring(dst []byte, s string) []byte { + l := int32(len(s) + 1) + dst = appendLength(dst, l) + dst = append(dst, s...) + return append(dst, 0x00) +} + +func readstring(src []byte) (string, []byte, bool) { + l, rem, ok := ReadLength(src) + if !ok { + return "", src, false + } + if len(src[4:]) < int(l) || l == 0 { + return "", src, false + } + + return string(rem[:l-1]), rem[l:], true +} + +// readLengthBytes attempts to read a length and that number of bytes. This +// function requires that the length include the four bytes for itself. +func readLengthBytes(src []byte) ([]byte, []byte, bool) { + l, _, ok := ReadLength(src) + if !ok { + return nil, src, false + } + if len(src) < int(l) { + return nil, src, false + } + return src[:l], src[l:], true +} + +func appendBinarySubtype2(dst []byte, subtype byte, b []byte) []byte { + dst = appendLength(dst, int32(len(b)+4)) // The bytes we'll encode need to be 4 larger for the length bytes + dst = append(dst, subtype) + dst = appendLength(dst, int32(len(b))) + return append(dst, b...) 
+} + +func isValidCString(cs string) bool { + return !strings.ContainsRune(cs, '\x00') +} diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/compression.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/compression.go new file mode 100644 index 0000000000..c474714ff4 --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/compression.go @@ -0,0 +1,145 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package driver + +import ( + "bytes" + "compress/zlib" + "fmt" + "io" + "sync" + + "github.com/golang/snappy" + "github.com/klauspost/compress/zstd" + "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" +) + +// CompressionOpts holds settings for how to compress a payload +type CompressionOpts struct { + Compressor wiremessage.CompressorID + ZlibLevel int + ZstdLevel int + UncompressedSize int32 +} + +var zstdEncoders sync.Map // map[zstd.EncoderLevel]*zstd.Encoder + +func getZstdEncoder(level zstd.EncoderLevel) (*zstd.Encoder, error) { + if v, ok := zstdEncoders.Load(level); ok { + return v.(*zstd.Encoder), nil + } + encoder, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(level)) + if err != nil { + return nil, err + } + zstdEncoders.Store(level, encoder) + return encoder, nil +} + +var zlibEncoders sync.Map // map[int /*level*/]*zlibEncoder + +func getZlibEncoder(level int) (*zlibEncoder, error) { + if v, ok := zlibEncoders.Load(level); ok { + return v.(*zlibEncoder), nil + } + writer, err := zlib.NewWriterLevel(nil, level) + if err != nil { + return nil, err + } + encoder := &zlibEncoder{writer: writer, buf: new(bytes.Buffer)} + zlibEncoders.Store(level, encoder) + + return encoder, nil +} + +type zlibEncoder struct { + mu sync.Mutex + writer *zlib.Writer + buf *bytes.Buffer +} + +func (e *zlibEncoder) Encode(dst, src []byte) ([]byte, error) { + e.mu.Lock() + defer e.mu.Unlock() + + e.buf.Reset() + e.writer.Reset(e.buf) + + _, err := e.writer.Write(src) + if err != nil { + return nil, err + } + err = e.writer.Close() + if err != nil { + return nil, err + } + dst = append(dst[:0], e.buf.Bytes()...) 
+ return dst, nil +} + +// CompressPayload takes a byte slice and compresses it according to the options passed +func CompressPayload(in []byte, opts CompressionOpts) ([]byte, error) { + switch opts.Compressor { + case wiremessage.CompressorNoOp: + return in, nil + case wiremessage.CompressorSnappy: + return snappy.Encode(nil, in), nil + case wiremessage.CompressorZLib: + encoder, err := getZlibEncoder(opts.ZlibLevel) + if err != nil { + return nil, err + } + return encoder.Encode(nil, in) + case wiremessage.CompressorZstd: + encoder, err := getZstdEncoder(zstd.EncoderLevelFromZstd(opts.ZstdLevel)) + if err != nil { + return nil, err + } + return encoder.EncodeAll(in, nil), nil + default: + return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor) + } +} + +// DecompressPayload takes a byte slice that has been compressed and undoes it according to the options passed +func DecompressPayload(in []byte, opts CompressionOpts) (uncompressed []byte, err error) { + switch opts.Compressor { + case wiremessage.CompressorNoOp: + return in, nil + case wiremessage.CompressorSnappy: + uncompressed = make([]byte, opts.UncompressedSize) + return snappy.Decode(uncompressed, in) + case wiremessage.CompressorZLib: + r, err := zlib.NewReader(bytes.NewReader(in)) + if err != nil { + return nil, err + } + defer func() { + err = r.Close() + }() + uncompressed = make([]byte, opts.UncompressedSize) + _, err = io.ReadFull(r, uncompressed) + if err != nil { + return nil, err + } + return uncompressed, nil + case wiremessage.CompressorZstd: + r, err := zstd.NewReader(bytes.NewBuffer(in)) + if err != nil { + return nil, err + } + defer r.Close() + uncompressed = make([]byte, opts.UncompressedSize) + _, err = io.ReadFull(r, uncompressed) + if err != nil { + return nil, err + } + return uncompressed, nil + default: + return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor) + } +} diff --git a/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/session/client_session.go b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/session/client_session.go new file mode 100644 index 0000000000..ba244b101e --- /dev/null +++ b/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/session/client_session.go @@ -0,0 +1,539 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package session // import "go.mongodb.org/mongo-driver/x/mongo/driver/session" + +import ( + "context" + "errors" + "time" + + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/internal/uuid" + "go.mongodb.org/mongo-driver/mongo/address" + "go.mongodb.org/mongo-driver/mongo/description" + "go.mongodb.org/mongo-driver/mongo/readconcern" + "go.mongodb.org/mongo-driver/mongo/readpref" + "go.mongodb.org/mongo-driver/mongo/writeconcern" + "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" +) + +// ErrSessionEnded is returned when a client session is used after a call to endSession(). +var ErrSessionEnded = errors.New("ended session was used") + +// ErrNoTransactStarted is returned if a transaction operation is called when no transaction has started. +var ErrNoTransactStarted = errors.New("no transaction started") + +// ErrTransactInProgress is returned if startTransaction() is called when a transaction is in progress. 
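+// Callers are expected to compare returned errors against these sentinel values;
+// for illustration (a sketch, with err coming from a transaction-start path):
+//
+//	if errors.Is(err, ErrTransactInProgress) {
+//		// a transaction is already running on this session
+//	}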
+var ErrTransactInProgress = errors.New("transaction already in progress")
+
+// ErrAbortAfterCommit is returned when abort is called after a commit.
+var ErrAbortAfterCommit = errors.New("cannot call abortTransaction after calling commitTransaction")
+
+// ErrAbortTwice is returned if abort is called after the transaction has already been aborted.
+var ErrAbortTwice = errors.New("cannot call abortTransaction twice")
+
+// ErrCommitAfterAbort is returned if commit is called after an abort.
+var ErrCommitAfterAbort = errors.New("cannot call commitTransaction after calling abortTransaction")
+
+// ErrUnackWCUnsupported is returned if an unacknowledged write concern is used for a transaction.
+var ErrUnackWCUnsupported = errors.New("transactions do not support unacknowledged write concerns")
+
+// ErrSnapshotTransaction is returned if a transaction is started on a snapshot session.
+var ErrSnapshotTransaction = errors.New("transactions are not supported in snapshot sessions")
+
+// TransactionState indicates the state of the transaction FSM.
+type TransactionState uint8
+
+// Client Session states
+const (
+	None TransactionState = iota
+	Starting
+	InProgress
+	Committed
+	Aborted
+)
+
+// String implements the fmt.Stringer interface.
+func (s TransactionState) String() string {
+	switch s {
+	case None:
+		return "none"
+	case Starting:
+		return "starting"
+	case InProgress:
+		return "in progress"
+	case Committed:
+		return "committed"
+	case Aborted:
+		return "aborted"
+	default:
+		return "unknown"
+	}
+}
+
+// LoadBalancedTransactionConnection represents a connection that's pinned by a ClientSession because it's being used
+// to execute a transaction when running against a load balancer. This interface is a copy of driver.PinnedConnection
+// and exists to be able to pin transactions to a connection without causing an import cycle.
+type LoadBalancedTransactionConnection interface {
+	// Functions copied over from driver.Connection.
+	WriteWireMessage(context.Context, []byte) error
+	ReadWireMessage(ctx context.Context) ([]byte, error)
+	Description() description.Server
+	Close() error
+	ID() string
+	ServerConnectionID() *int64
+	DriverConnectionID() uint64 // TODO(GODRIVER-2824): change type to int64.
+	Address() address.Address
+	Stale() bool
+
+	// Functions copied over from driver.PinnedConnection that are not part of Connection or Expirable.
+	PinToCursor() error
+	PinToTransaction() error
+	UnpinFromCursor() error
+	UnpinFromTransaction() error
+}
+
+// Client is a session for clients to run commands.
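+// It tracks the session's cluster and operation times, its transaction state machine, and any
+// server or connection pinned for the duration of a transaction.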
+type Client struct { + *Server + ClientID uuid.UUID + ClusterTime bson.Raw + Consistent bool // causal consistency + OperationTime *primitive.Timestamp + IsImplicit bool + Terminated bool + RetryingCommit bool + Committing bool + Aborting bool + RetryWrite bool + RetryRead bool + Snapshot bool + + // options for the current transaction + // most recently set by transactionopt + CurrentRc *readconcern.ReadConcern + CurrentRp *readpref.ReadPref + CurrentWc *writeconcern.WriteConcern + CurrentMct *time.Duration + + // default transaction options + transactionRc *readconcern.ReadConcern + transactionRp *readpref.ReadPref + transactionWc *writeconcern.WriteConcern + transactionMaxCommitTime *time.Duration + + pool *Pool + TransactionState TransactionState + PinnedServer *description.Server + RecoveryToken bson.Raw + PinnedConnection LoadBalancedTransactionConnection + SnapshotTime *primitive.Timestamp +} + +func getClusterTime(clusterTime bson.Raw) (uint32, uint32) { + if clusterTime == nil { + return 0, 0 + } + + clusterTimeVal, err := clusterTime.LookupErr("$clusterTime") + if err != nil { + return 0, 0 + } + + timestampVal, err := bson.Raw(clusterTimeVal.Value).LookupErr("clusterTime") + if err != nil { + return 0, 0 + } + + return timestampVal.Timestamp() +} + +// MaxClusterTime compares 2 clusterTime documents and returns the document representing the highest cluster time. +func MaxClusterTime(ct1, ct2 bson.Raw) bson.Raw { + epoch1, ord1 := getClusterTime(ct1) + epoch2, ord2 := getClusterTime(ct2) + + if epoch1 > epoch2 { + return ct1 + } else if epoch1 < epoch2 { + return ct2 + } else if ord1 > ord2 { + return ct1 + } else if ord1 < ord2 { + return ct2 + } + + return ct1 +} + +// NewImplicitClientSession creates a new implicit client-side session. +func NewImplicitClientSession(pool *Pool, clientID uuid.UUID) *Client { + // Server-side session checkout for implicit sessions is deferred until after checking out a + // connection, so don't check out a server-side session right now. This will limit the number of + // implicit sessions to no greater than an application's maxPoolSize. + + return &Client{ + pool: pool, + ClientID: clientID, + IsImplicit: true, + } +} + +// NewClientSession creates a new explicit client-side session. +func NewClientSession(pool *Pool, clientID uuid.UUID, opts ...*ClientOptions) (*Client, error) { + c := &Client{ + pool: pool, + ClientID: clientID, + } + + mergedOpts := mergeClientOptions(opts...) + if mergedOpts.DefaultReadPreference != nil { + c.transactionRp = mergedOpts.DefaultReadPreference + } + if mergedOpts.DefaultReadConcern != nil { + c.transactionRc = mergedOpts.DefaultReadConcern + } + if mergedOpts.DefaultWriteConcern != nil { + c.transactionWc = mergedOpts.DefaultWriteConcern + } + if mergedOpts.DefaultMaxCommitTime != nil { + c.transactionMaxCommitTime = mergedOpts.DefaultMaxCommitTime + } + if mergedOpts.Snapshot != nil { + c.Snapshot = *mergedOpts.Snapshot + } + + // For explicit sessions, the default for causalConsistency is true, unless Snapshot is + // enabled, then it's false. Set the default and then allow any explicit causalConsistency + // setting to override it. 
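+	// For example, a snapshot session with no explicit causalConsistency option ends up
+	// with Consistent == false.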
+ c.Consistent = !c.Snapshot + if mergedOpts.CausalConsistency != nil { + c.Consistent = *mergedOpts.CausalConsistency + } + + if c.Consistent && c.Snapshot { + return nil, errors.New("causal consistency and snapshot cannot both be set for a session") + } + + if err := c.SetServer(); err != nil { + return nil, err + } + + return c, nil +} + +// SetServer will check out a session from the client session pool. +func (c *Client) SetServer() error { + var err error + c.Server, err = c.pool.GetSession() + return err +} + +// AdvanceClusterTime updates the session's cluster time. +func (c *Client) AdvanceClusterTime(clusterTime bson.Raw) error { + if c.Terminated { + return ErrSessionEnded + } + c.ClusterTime = MaxClusterTime(c.ClusterTime, clusterTime) + return nil +} + +// AdvanceOperationTime updates the session's operation time. +func (c *Client) AdvanceOperationTime(opTime *primitive.Timestamp) error { + if c.Terminated { + return ErrSessionEnded + } + + if c.OperationTime == nil { + c.OperationTime = opTime + return nil + } + + if opTime.T > c.OperationTime.T { + c.OperationTime = opTime + } else if (opTime.T == c.OperationTime.T) && (opTime.I > c.OperationTime.I) { + c.OperationTime = opTime + } + + return nil +} + +// UpdateUseTime sets the session's last used time to the current time. This must be called whenever the session is +// used to send a command to the server to ensure that the session is not prematurely marked expired in the driver's +// session pool. If the session has already been ended, this method will return ErrSessionEnded. +func (c *Client) UpdateUseTime() error { + if c.Terminated { + return ErrSessionEnded + } + c.updateUseTime() + return nil +} + +// UpdateRecoveryToken updates the session's recovery token from the server response. +func (c *Client) UpdateRecoveryToken(response bson.Raw) { + if c == nil { + return + } + + token, err := response.LookupErr("recoveryToken") + if err != nil { + return + } + + c.RecoveryToken = token.Document() +} + +// UpdateSnapshotTime updates the session's value for the atClusterTime field of ReadConcern. +func (c *Client) UpdateSnapshotTime(response bsoncore.Document) { + if c == nil { + return + } + + subDoc := response + if cur, ok := response.Lookup("cursor").DocumentOK(); ok { + subDoc = cur + } + + ssTimeElem, err := subDoc.LookupErr("atClusterTime") + if err != nil { + // atClusterTime not included by the server + return + } + + t, i := ssTimeElem.Timestamp() + c.SnapshotTime = &primitive.Timestamp{ + T: t, + I: i, + } +} + +// ClearPinnedResources clears the pinned server and/or connection associated with the session. +func (c *Client) ClearPinnedResources() error { + if c == nil { + return nil + } + + c.PinnedServer = nil + if c.PinnedConnection != nil { + if err := c.PinnedConnection.UnpinFromTransaction(); err != nil { + return err + } + if err := c.PinnedConnection.Close(); err != nil { + return err + } + } + c.PinnedConnection = nil + return nil +} + +// UnpinConnection gracefully unpins the connection associated with the session if there is one. This is done via +// the pinned connection's UnpinFromTransaction function. +func (c *Client) UnpinConnection() error { + if c == nil || c.PinnedConnection == nil { + return nil + } + + err := c.PinnedConnection.UnpinFromTransaction() + closeErr := c.PinnedConnection.Close() + if err == nil && closeErr != nil { + err = closeErr + } + c.PinnedConnection = nil + return err +} + +// EndSession ends the session. 
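+// Once ended, the server-side session is returned to the pool, and subsequent calls to methods
+// such as AdvanceClusterTime, AdvanceOperationTime, and UpdateUseTime return ErrSessionEnded.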
+func (c *Client) EndSession() {
+	if c.Terminated {
+		return
+	}
+	c.Terminated = true
+	c.pool.ReturnSession(c.Server)
+}
+
+// TransactionInProgress returns true if the client session is in an active transaction.
+func (c *Client) TransactionInProgress() bool {
+	return c.TransactionState == InProgress
+}
+
+// TransactionStarting returns true if the client session is starting a transaction.
+func (c *Client) TransactionStarting() bool {
+	return c.TransactionState == Starting
+}
+
+// TransactionRunning returns true if the client session has started the transaction
+// and it hasn't been committed or aborted.
+func (c *Client) TransactionRunning() bool {
+	return c != nil && (c.TransactionState == Starting || c.TransactionState == InProgress)
+}
+
+// TransactionCommitted returns true if the client session just committed a transaction.
+func (c *Client) TransactionCommitted() bool {
+	return c.TransactionState == Committed
+}
+
+// CheckStartTransaction checks whether a transaction is allowed to start and returns
+// an error if it is not.
+func (c *Client) CheckStartTransaction() error {
+	if c.TransactionState == InProgress || c.TransactionState == Starting {
+		return ErrTransactInProgress
+	}
+	if c.Snapshot {
+		return ErrSnapshotTransaction
+	}
+	return nil
+}
+
+// StartTransaction initializes the transaction options and advances the state machine.
+// It does not contact the server to start the transaction.
+func (c *Client) StartTransaction(opts *TransactionOptions) error {
+	err := c.CheckStartTransaction()
+	if err != nil {
+		return err
+	}
+
+	c.IncrementTxnNumber()
+	c.RetryingCommit = false
+
+	if opts != nil {
+		c.CurrentRc = opts.ReadConcern
+		c.CurrentRp = opts.ReadPreference
+		c.CurrentWc = opts.WriteConcern
+		c.CurrentMct = opts.MaxCommitTime
+	}
+
+	if c.CurrentRc == nil {
+		c.CurrentRc = c.transactionRc
+	}
+
+	if c.CurrentRp == nil {
+		c.CurrentRp = c.transactionRp
+	}
+
+	if c.CurrentWc == nil {
+		c.CurrentWc = c.transactionWc
+	}
+
+	if c.CurrentMct == nil {
+		c.CurrentMct = c.transactionMaxCommitTime
+	}
+
+	if !writeconcern.AckWrite(c.CurrentWc) {
+		_ = c.clearTransactionOpts()
+		return ErrUnackWCUnsupported
+	}
+
+	c.TransactionState = Starting
+	return c.ClearPinnedResources()
+}
+
+// CheckCommitTransaction checks whether the transaction is allowed to be committed and returns
+// an error if it is not.
+func (c *Client) CheckCommitTransaction() error {
+	if c.TransactionState == None {
+		return ErrNoTransactStarted
+	} else if c.TransactionState == Aborted {
+		return ErrCommitAfterAbort
+	}
+	return nil
+}
+
+// CommitTransaction updates the state for a successfully committed transaction and returns
+// an error if not permissible. It does not actually perform the commit.
+func (c *Client) CommitTransaction() error {
+	err := c.CheckCommitTransaction()
+	if err != nil {
+		return err
+	}
+	c.TransactionState = Committed
+	return nil
+}
+
+// UpdateCommitTransactionWriteConcern will set the write concern to majority and potentially set a
+// w timeout of 10 seconds. This should be called after a commit transaction operation fails with a
+// retryable error or after a successful commit transaction operation.
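+// An existing wTimeout on the current write concern is preserved; otherwise a wTimeout of
+// 10 seconds is applied.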
+func (c *Client) UpdateCommitTransactionWriteConcern() {
+	wc := c.CurrentWc
+	timeout := 10 * time.Second
+	if wc != nil && wc.GetWTimeout() != 0 {
+		timeout = wc.GetWTimeout()
+	}
+	c.CurrentWc = wc.WithOptions(writeconcern.WMajority(), writeconcern.WTimeout(timeout))
+}
+
+// CheckAbortTransaction checks whether the transaction is allowed to be aborted and returns
+// an error if it is not.
+func (c *Client) CheckAbortTransaction() error {
+	if c.TransactionState == None {
+		return ErrNoTransactStarted
+	} else if c.TransactionState == Committed {
+		return ErrAbortAfterCommit
+	} else if c.TransactionState == Aborted {
+		return ErrAbortTwice
+	}
+	return nil
+}
+
+// AbortTransaction updates the state for a successfully aborted transaction and returns
+// an error if not permissible. It does not actually perform the abort.
+func (c *Client) AbortTransaction() error {
+	err := c.CheckAbortTransaction()
+	if err != nil {
+		return err
+	}
+	c.TransactionState = Aborted
+	return c.clearTransactionOpts()
+}
+
+// StartCommand updates the session's internal state at the beginning of an operation. This must be called before
+// server selection is done for the operation as the session's state can impact the result of that process.
+func (c *Client) StartCommand() error {
+	if c == nil {
+		return nil
+	}
+
+	// If we're executing the first operation using this session after a transaction, we must ensure that the session
+	// is not pinned to any resources.
+	if !c.TransactionRunning() && !c.Committing && !c.Aborting {
+		return c.ClearPinnedResources()
+	}
+	return nil
+}
+
+// ApplyCommand advances the state machine upon command execution. This must be called after server selection is
+// complete.
+func (c *Client) ApplyCommand(desc description.Server) error {
+	if c.Committing {
+		// Do not change state if committing after already committed
+		return nil
+	}
+	if c.TransactionState == Starting {
+		c.TransactionState = InProgress
+		// If this is in a transaction and the server is a mongos, pin it
+		if desc.Kind == description.Mongos {
+			c.PinnedServer = &desc
+		}
+	} else if c.TransactionState == Committed || c.TransactionState == Aborted {
+		c.TransactionState = None
+		return c.clearTransactionOpts()
+	}
+
+	return nil
+}
+
+func (c *Client) clearTransactionOpts() error {
+	c.RetryingCommit = false
+	c.Aborting = false
+	c.Committing = false
+	c.CurrentWc = nil
+	c.CurrentRp = nil
+	c.CurrentRc = nil
+	c.RecoveryToken = nil
+
+	return c.ClearPinnedResources()
+}
diff --git a/vendor/go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc/config.go b/vendor/go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc/config.go
new file mode 100644
index 0000000000..e3e1d452b0
--- /dev/null
+++ b/vendor/go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc/config.go
@@ -0,0 +1,133 @@
+// Copyright The OpenTelemetry Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package otelgrpc // import "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
+
+import (
+	"go.opentelemetry.io/otel"
+	"go.opentelemetry.io/otel/attribute"
+	"go.opentelemetry.io/otel/metric"
+	"go.opentelemetry.io/otel/propagation"
+	semconv "go.opentelemetry.io/otel/semconv/v1.17.0"
+	"go.opentelemetry.io/otel/trace"
+)
+
+const (
+	// instrumentationName is the name of this instrumentation package.
+	instrumentationName = "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
+	// GRPCStatusCodeKey is the convention for the numeric status code of a gRPC request.
+	GRPCStatusCodeKey = attribute.Key("rpc.grpc.status_code")
+)
+
+// Filter is a predicate used to determine whether a given request in
+// interceptor info should be traced. A Filter must return true if
+// the request should be traced.
+type Filter func(*InterceptorInfo) bool
+
+// config is a group of options for this instrumentation.
+type config struct {
+	Filter         Filter
+	Propagators    propagation.TextMapPropagator
+	TracerProvider trace.TracerProvider
+	MeterProvider  metric.MeterProvider
+
+	meter             metric.Meter
+	rpcServerDuration metric.Int64Histogram
+}
+
+// Option applies an option value for a config.
+type Option interface {
+	apply(*config)
+}
+
+// newConfig returns a config configured with all the passed Options.
+func newConfig(opts []Option) *config {
+	c := &config{
+		Propagators:    otel.GetTextMapPropagator(),
+		TracerProvider: otel.GetTracerProvider(),
+		MeterProvider:  otel.GetMeterProvider(),
+	}
+	for _, o := range opts {
+		o.apply(c)
+	}
+
+	c.meter = c.MeterProvider.Meter(
+		instrumentationName,
+		metric.WithInstrumentationVersion(Version()),
+		metric.WithSchemaURL(semconv.SchemaURL),
+	)
+	var err error
+	if c.rpcServerDuration, err = c.meter.Int64Histogram("rpc.server.duration", metric.WithUnit("ms")); err != nil {
+		otel.Handle(err)
+	}
+
+	return c
+}
+
+type propagatorsOption struct{ p propagation.TextMapPropagator }
+
+func (o propagatorsOption) apply(c *config) {
+	if o.p != nil {
+		c.Propagators = o.p
+	}
+}
+
+// WithPropagators returns an Option to use the Propagators when extracting
+// and injecting trace context from requests.
+func WithPropagators(p propagation.TextMapPropagator) Option {
+	return propagatorsOption{p: p}
+}
+
+type tracerProviderOption struct{ tp trace.TracerProvider }
+
+func (o tracerProviderOption) apply(c *config) {
+	if o.tp != nil {
+		c.TracerProvider = o.tp
+	}
+}
+
+// WithInterceptorFilter returns an Option to use the request filter.
+func WithInterceptorFilter(f Filter) Option {
+	return interceptorFilterOption{f: f}
+}
+
+type interceptorFilterOption struct {
+	f Filter
+}
+
+func (o interceptorFilterOption) apply(c *config) {
+	if o.f != nil {
+		c.Filter = o.f
+	}
+}
+
+// WithTracerProvider returns an Option to use the TracerProvider when
+// creating a Tracer.
+func WithTracerProvider(tp trace.TracerProvider) Option {
+	return tracerProviderOption{tp: tp}
+}
+
+type meterProviderOption struct{ mp metric.MeterProvider }
+
+func (o meterProviderOption) apply(c *config) {
+	if o.mp != nil {
+		c.MeterProvider = o.mp
+	}
+}
+
+// WithMeterProvider returns an Option to use the MeterProvider when
+// creating a Meter. If this option is not provided, the global MeterProvider will be used.
+func WithMeterProvider(mp metric.MeterProvider) Option { + return meterProviderOption{mp: mp} +} diff --git a/vendor/go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc/interceptor.go b/vendor/go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc/interceptor.go new file mode 100644 index 0000000000..d4dc5de5a9 --- /dev/null +++ b/vendor/go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc/interceptor.go @@ -0,0 +1,527 @@ +// Copyright The OpenTelemetry Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package otelgrpc // import "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" + +// gRPC tracing middleware +// https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/trace/semantic_conventions/rpc.md +import ( + "context" + "io" + "net" + "strconv" + "time" + + "google.golang.org/grpc" + grpc_codes "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/peer" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" + + "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc/internal" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/metric" + semconv "go.opentelemetry.io/otel/semconv/v1.17.0" + "go.opentelemetry.io/otel/trace" +) + +type messageType attribute.KeyValue + +// Event adds an event of the messageType to the span associated with the +// passed context with a message id. +func (m messageType) Event(ctx context.Context, id int, _ interface{}) { + span := trace.SpanFromContext(ctx) + if !span.IsRecording() { + return + } + span.AddEvent("message", trace.WithAttributes( + attribute.KeyValue(m), + RPCMessageIDKey.Int(id), + )) +} + +var ( + messageSent = messageType(RPCMessageTypeSent) + messageReceived = messageType(RPCMessageTypeReceived) +) + +// UnaryClientInterceptor returns a grpc.UnaryClientInterceptor suitable +// for use in a grpc.Dial call. +func UnaryClientInterceptor(opts ...Option) grpc.UnaryClientInterceptor { + cfg := newConfig(opts) + tracer := cfg.TracerProvider.Tracer( + instrumentationName, + trace.WithInstrumentationVersion(Version()), + ) + + return func( + ctx context.Context, + method string, + req, reply interface{}, + cc *grpc.ClientConn, + invoker grpc.UnaryInvoker, + callOpts ...grpc.CallOption, + ) error { + i := &InterceptorInfo{ + Method: method, + Type: UnaryClient, + } + if cfg.Filter != nil && !cfg.Filter(i) { + return invoker(ctx, method, req, reply, cc, callOpts...) + } + + name, attr := spanInfo(method, cc.Target()) + var span trace.Span + ctx, span = tracer.Start( + ctx, + name, + trace.WithSpanKind(trace.SpanKindClient), + trace.WithAttributes(attr...), + ) + defer span.End() + + ctx = inject(ctx, cfg.Propagators) + + messageSent.Event(ctx, 1, req) + + err := invoker(ctx, method, req, reply, cc, callOpts...) 
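+		// The unary call has completed; record the reply as a received-message event and
+		// map the gRPC status onto the span below.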
+ + messageReceived.Event(ctx, 1, reply) + + if err != nil { + s, _ := status.FromError(err) + span.SetStatus(codes.Error, s.Message()) + span.SetAttributes(statusCodeAttr(s.Code())) + } else { + span.SetAttributes(statusCodeAttr(grpc_codes.OK)) + } + + return err + } +} + +type streamEventType int + +type streamEvent struct { + Type streamEventType + Err error +} + +const ( + receiveEndEvent streamEventType = iota + errorEvent +) + +// clientStream wraps around the embedded grpc.ClientStream, and intercepts the RecvMsg and +// SendMsg method call. +type clientStream struct { + grpc.ClientStream + + desc *grpc.StreamDesc + events chan streamEvent + eventsDone chan struct{} + finished chan error + + receivedMessageID int + sentMessageID int +} + +var _ = proto.Marshal + +func (w *clientStream) RecvMsg(m interface{}) error { + err := w.ClientStream.RecvMsg(m) + + if err == nil && !w.desc.ServerStreams { + w.sendStreamEvent(receiveEndEvent, nil) + } else if err == io.EOF { + w.sendStreamEvent(receiveEndEvent, nil) + } else if err != nil { + w.sendStreamEvent(errorEvent, err) + } else { + w.receivedMessageID++ + messageReceived.Event(w.Context(), w.receivedMessageID, m) + } + + return err +} + +func (w *clientStream) SendMsg(m interface{}) error { + err := w.ClientStream.SendMsg(m) + + w.sentMessageID++ + messageSent.Event(w.Context(), w.sentMessageID, m) + + if err != nil { + w.sendStreamEvent(errorEvent, err) + } + + return err +} + +func (w *clientStream) Header() (metadata.MD, error) { + md, err := w.ClientStream.Header() + + if err != nil { + w.sendStreamEvent(errorEvent, err) + } + + return md, err +} + +func (w *clientStream) CloseSend() error { + err := w.ClientStream.CloseSend() + + if err != nil { + w.sendStreamEvent(errorEvent, err) + } + + return err +} + +func wrapClientStream(ctx context.Context, s grpc.ClientStream, desc *grpc.StreamDesc) *clientStream { + events := make(chan streamEvent) + eventsDone := make(chan struct{}) + finished := make(chan error) + + go func() { + defer close(eventsDone) + + for { + select { + case event := <-events: + switch event.Type { + case receiveEndEvent: + finished <- nil + return + case errorEvent: + finished <- event.Err + return + } + case <-ctx.Done(): + finished <- ctx.Err() + return + } + } + }() + + return &clientStream{ + ClientStream: s, + desc: desc, + events: events, + eventsDone: eventsDone, + finished: finished, + } +} + +func (w *clientStream) sendStreamEvent(eventType streamEventType, err error) { + select { + case <-w.eventsDone: + case w.events <- streamEvent{Type: eventType, Err: err}: + } +} + +// StreamClientInterceptor returns a grpc.StreamClientInterceptor suitable +// for use in a grpc.Dial call. +func StreamClientInterceptor(opts ...Option) grpc.StreamClientInterceptor { + cfg := newConfig(opts) + tracer := cfg.TracerProvider.Tracer( + instrumentationName, + trace.WithInstrumentationVersion(Version()), + ) + + return func( + ctx context.Context, + desc *grpc.StreamDesc, + cc *grpc.ClientConn, + method string, + streamer grpc.Streamer, + callOpts ...grpc.CallOption, + ) (grpc.ClientStream, error) { + i := &InterceptorInfo{ + Method: method, + Type: StreamClient, + } + if cfg.Filter != nil && !cfg.Filter(i) { + return streamer(ctx, desc, cc, method, callOpts...) 
+ } + + name, attr := spanInfo(method, cc.Target()) + var span trace.Span + ctx, span = tracer.Start( + ctx, + name, + trace.WithSpanKind(trace.SpanKindClient), + trace.WithAttributes(attr...), + ) + + ctx = inject(ctx, cfg.Propagators) + + s, err := streamer(ctx, desc, cc, method, callOpts...) + if err != nil { + grpcStatus, _ := status.FromError(err) + span.SetStatus(codes.Error, grpcStatus.Message()) + span.SetAttributes(statusCodeAttr(grpcStatus.Code())) + span.End() + return s, err + } + stream := wrapClientStream(ctx, s, desc) + + go func() { + err := <-stream.finished + + if err != nil { + s, _ := status.FromError(err) + span.SetStatus(codes.Error, s.Message()) + span.SetAttributes(statusCodeAttr(s.Code())) + } else { + span.SetAttributes(statusCodeAttr(grpc_codes.OK)) + } + + span.End() + }() + + return stream, nil + } +} + +// UnaryServerInterceptor returns a grpc.UnaryServerInterceptor suitable +// for use in a grpc.NewServer call. +func UnaryServerInterceptor(opts ...Option) grpc.UnaryServerInterceptor { + cfg := newConfig(opts) + tracer := cfg.TracerProvider.Tracer( + instrumentationName, + trace.WithInstrumentationVersion(Version()), + ) + + return func( + ctx context.Context, + req interface{}, + info *grpc.UnaryServerInfo, + handler grpc.UnaryHandler, + ) (interface{}, error) { + i := &InterceptorInfo{ + UnaryServerInfo: info, + Type: UnaryServer, + } + if cfg.Filter != nil && !cfg.Filter(i) { + return handler(ctx, req) + } + + ctx = extract(ctx, cfg.Propagators) + + name, attr := spanInfo(info.FullMethod, peerFromCtx(ctx)) + ctx, span := tracer.Start( + trace.ContextWithRemoteSpanContext(ctx, trace.SpanContextFromContext(ctx)), + name, + trace.WithSpanKind(trace.SpanKindServer), + trace.WithAttributes(attr...), + ) + defer span.End() + + messageReceived.Event(ctx, 1, req) + + var statusCode grpc_codes.Code + defer func(t time.Time) { + elapsedTime := time.Since(t) / time.Millisecond + attr = append(attr, semconv.RPCGRPCStatusCodeKey.Int64(int64(statusCode))) + o := metric.WithAttributes(attr...) + cfg.rpcServerDuration.Record(ctx, int64(elapsedTime), o) + }(time.Now()) + + resp, err := handler(ctx, req) + if err != nil { + s, _ := status.FromError(err) + statusCode, msg := serverStatus(s) + span.SetStatus(statusCode, msg) + span.SetAttributes(statusCodeAttr(s.Code())) + messageSent.Event(ctx, 1, s.Proto()) + } else { + statusCode = grpc_codes.OK + span.SetAttributes(statusCodeAttr(grpc_codes.OK)) + messageSent.Event(ctx, 1, resp) + } + + return resp, err + } +} + +// serverStream wraps around the embedded grpc.ServerStream, and intercepts the RecvMsg and +// SendMsg method call. +type serverStream struct { + grpc.ServerStream + ctx context.Context + + receivedMessageID int + sentMessageID int +} + +func (w *serverStream) Context() context.Context { + return w.ctx +} + +func (w *serverStream) RecvMsg(m interface{}) error { + err := w.ServerStream.RecvMsg(m) + + if err == nil { + w.receivedMessageID++ + messageReceived.Event(w.Context(), w.receivedMessageID, m) + } + + return err +} + +func (w *serverStream) SendMsg(m interface{}) error { + err := w.ServerStream.SendMsg(m) + + w.sentMessageID++ + messageSent.Event(w.Context(), w.sentMessageID, m) + + return err +} + +func wrapServerStream(ctx context.Context, ss grpc.ServerStream) *serverStream { + return &serverStream{ + ServerStream: ss, + ctx: ctx, + } +} + +// StreamServerInterceptor returns a grpc.StreamServerInterceptor suitable +// for use in a grpc.NewServer call. 
+func StreamServerInterceptor(opts ...Option) grpc.StreamServerInterceptor { + cfg := newConfig(opts) + tracer := cfg.TracerProvider.Tracer( + instrumentationName, + trace.WithInstrumentationVersion(Version()), + ) + + return func( + srv interface{}, + ss grpc.ServerStream, + info *grpc.StreamServerInfo, + handler grpc.StreamHandler, + ) error { + ctx := ss.Context() + i := &InterceptorInfo{ + StreamServerInfo: info, + Type: StreamServer, + } + if cfg.Filter != nil && !cfg.Filter(i) { + return handler(srv, wrapServerStream(ctx, ss)) + } + + ctx = extract(ctx, cfg.Propagators) + + name, attr := spanInfo(info.FullMethod, peerFromCtx(ctx)) + ctx, span := tracer.Start( + trace.ContextWithRemoteSpanContext(ctx, trace.SpanContextFromContext(ctx)), + name, + trace.WithSpanKind(trace.SpanKindServer), + trace.WithAttributes(attr...), + ) + defer span.End() + + err := handler(srv, wrapServerStream(ctx, ss)) + if err != nil { + s, _ := status.FromError(err) + statusCode, msg := serverStatus(s) + span.SetStatus(statusCode, msg) + span.SetAttributes(statusCodeAttr(s.Code())) + } else { + span.SetAttributes(statusCodeAttr(grpc_codes.OK)) + } + + return err + } +} + +// spanInfo returns a span name and all appropriate attributes from the gRPC +// method and peer address. +func spanInfo(fullMethod, peerAddress string) (string, []attribute.KeyValue) { + attrs := []attribute.KeyValue{RPCSystemGRPC} + name, mAttrs := internal.ParseFullMethod(fullMethod) + attrs = append(attrs, mAttrs...) + attrs = append(attrs, peerAttr(peerAddress)...) + return name, attrs +} + +// peerAttr returns attributes about the peer address. +func peerAttr(addr string) []attribute.KeyValue { + host, p, err := net.SplitHostPort(addr) + if err != nil { + return []attribute.KeyValue(nil) + } + + if host == "" { + host = "127.0.0.1" + } + port, err := strconv.Atoi(p) + if err != nil { + return []attribute.KeyValue(nil) + } + + var attr []attribute.KeyValue + if ip := net.ParseIP(host); ip != nil { + attr = []attribute.KeyValue{ + semconv.NetSockPeerAddr(host), + semconv.NetSockPeerPort(port), + } + } else { + attr = []attribute.KeyValue{ + semconv.NetPeerName(host), + semconv.NetPeerPort(port), + } + } + + return attr +} + +// peerFromCtx returns a peer address from a context, if one exists. +func peerFromCtx(ctx context.Context) string { + p, ok := peer.FromContext(ctx) + if !ok { + return "" + } + return p.Addr.String() +} + +// statusCodeAttr returns status code attribute based on given gRPC code. +func statusCodeAttr(c grpc_codes.Code) attribute.KeyValue { + return GRPCStatusCodeKey.Int64(int64(c)) +} + +// serverStatus returns a span status code and message for a given gRPC +// status code. It maps specific gRPC status codes to a corresponding span +// status code and message. This function is intended for use on the server +// side of a gRPC connection. +// +// If the gRPC status code is Unknown, DeadlineExceeded, Unimplemented, +// Internal, Unavailable, or DataLoss, it returns a span status code of Error +// and the message from the gRPC status. Otherwise, it returns a span status +// code of Unset and an empty message. 
+func serverStatus(grpcStatus *status.Status) (codes.Code, string) { + switch grpcStatus.Code() { + case grpc_codes.Unknown, + grpc_codes.DeadlineExceeded, + grpc_codes.Unimplemented, + grpc_codes.Internal, + grpc_codes.Unavailable, + grpc_codes.DataLoss: + return codes.Error, grpcStatus.Message() + default: + return codes.Unset, "" + } +} diff --git a/vendor/go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc/internal/parse.go b/vendor/go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc/internal/parse.go new file mode 100644 index 0000000000..ae160d5875 --- /dev/null +++ b/vendor/go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc/internal/parse.go @@ -0,0 +1,43 @@ +// Copyright The OpenTelemetry Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal // import "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc/internal" + +import ( + "strings" + + "go.opentelemetry.io/otel/attribute" + semconv "go.opentelemetry.io/otel/semconv/v1.17.0" +) + +// ParseFullMethod returns a span name following the OpenTelemetry semantic +// conventions as well as all applicable span attribute.KeyValue attributes based +// on a gRPC's FullMethod. +func ParseFullMethod(fullMethod string) (string, []attribute.KeyValue) { + name := strings.TrimLeft(fullMethod, "/") + service, method, found := strings.Cut(name, "/") + if !found { + // Invalid format, does not follow `/package.service/method`. + return name, nil + } + + var attrs []attribute.KeyValue + if service != "" { + attrs = append(attrs, semconv.RPCService(service)) + } + if method != "" { + attrs = append(attrs, semconv.RPCMethod(method)) + } + return name, attrs +} diff --git a/vendor/go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc/version.go b/vendor/go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc/version.go new file mode 100644 index 0000000000..1fc5e3365d --- /dev/null +++ b/vendor/go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc/version.go @@ -0,0 +1,28 @@ +// Copyright The OpenTelemetry Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package otelgrpc // import "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" + +// Version is the current release version of the gRPC instrumentation. 
+func Version() string { + return "0.42.0" + // This string is updated by the pre_release.sh script during release +} + +// SemVersion is the semantic version to be supplied to tracer/meter creation. +// +// Deprecated: Use [Version] instead. +func SemVersion() string { + return Version() +} diff --git a/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/handler.go b/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/handler.go new file mode 100644 index 0000000000..f2f20e3b93 --- /dev/null +++ b/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/handler.go @@ -0,0 +1,264 @@ +// Copyright The OpenTelemetry Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package otelhttp // import "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" + +import ( + "io" + "net/http" + "time" + + "github.com/felixge/httpsnoop" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" + "go.opentelemetry.io/otel/propagation" + semconv "go.opentelemetry.io/otel/semconv/v1.17.0" + "go.opentelemetry.io/otel/semconv/v1.17.0/httpconv" + "go.opentelemetry.io/otel/trace" +) + +var _ http.Handler = &Handler{} + +// Handler is http middleware that corresponds to the http.Handler interface and +// is designed to wrap a http.Mux (or equivalent), while individual routes on +// the mux are wrapped with WithRouteTag. A Handler will add various attributes +// to the span using the attribute.Keys defined in this package. +type Handler struct { + operation string + server string + handler http.Handler + + tracer trace.Tracer + meter metric.Meter + propagators propagation.TextMapPropagator + spanStartOptions []trace.SpanStartOption + readEvent bool + writeEvent bool + filters []Filter + spanNameFormatter func(string, *http.Request) string + counters map[string]metric.Int64Counter + valueRecorders map[string]metric.Float64Histogram + publicEndpoint bool + publicEndpointFn func(*http.Request) bool +} + +func defaultHandlerFormatter(operation string, _ *http.Request) string { + return operation +} + +// NewHandler wraps the passed handler, functioning like middleware, in a span +// named after the operation and with any provided Options. +func NewHandler(handler http.Handler, operation string, opts ...Option) http.Handler { + h := Handler{ + handler: handler, + operation: operation, + } + + defaultOpts := []Option{ + WithSpanOptions(trace.WithSpanKind(trace.SpanKindServer)), + WithSpanNameFormatter(defaultHandlerFormatter), + } + + c := newConfig(append(defaultOpts, opts...)...) 
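+	// Caller-supplied options are appended after the defaults, so they can override the
+	// default span kind and span name formatter.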
+	h.configure(c)
+	h.createMeasures()
+
+	return &h
+}
+
+func (h *Handler) configure(c *config) {
+	h.tracer = c.Tracer
+	h.meter = c.Meter
+	h.propagators = c.Propagators
+	h.spanStartOptions = c.SpanStartOptions
+	h.readEvent = c.ReadEvent
+	h.writeEvent = c.WriteEvent
+	h.filters = c.Filters
+	h.spanNameFormatter = c.SpanNameFormatter
+	h.publicEndpoint = c.PublicEndpoint
+	h.publicEndpointFn = c.PublicEndpointFn
+	h.server = c.ServerName
+}
+
+func handleErr(err error) {
+	if err != nil {
+		otel.Handle(err)
+	}
+}
+
+func (h *Handler) createMeasures() {
+	h.counters = make(map[string]metric.Int64Counter)
+	h.valueRecorders = make(map[string]metric.Float64Histogram)
+
+	requestBytesCounter, err := h.meter.Int64Counter(RequestContentLength)
+	handleErr(err)
+
+	responseBytesCounter, err := h.meter.Int64Counter(ResponseContentLength)
+	handleErr(err)
+
+	serverLatencyMeasure, err := h.meter.Float64Histogram(ServerLatency)
+	handleErr(err)
+
+	h.counters[RequestContentLength] = requestBytesCounter
+	h.counters[ResponseContentLength] = responseBytesCounter
+	h.valueRecorders[ServerLatency] = serverLatencyMeasure
+}
+
+// ServeHTTP serves HTTP requests (http.Handler).
+func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	requestStartTime := time.Now()
+	for _, f := range h.filters {
+		if !f(r) {
+			// Simply pass through to the handler if a filter rejects the request
+			h.handler.ServeHTTP(w, r)
+			return
+		}
+	}
+
+	ctx := h.propagators.Extract(r.Context(), propagation.HeaderCarrier(r.Header))
+	opts := []trace.SpanStartOption{
+		trace.WithAttributes(httpconv.ServerRequest(h.server, r)...),
+	}
+	if h.server != "" {
+		hostAttr := semconv.NetHostName(h.server)
+		opts = append(opts, trace.WithAttributes(hostAttr))
+	}
+	opts = append(opts, h.spanStartOptions...)
+	if h.publicEndpoint || (h.publicEndpointFn != nil && h.publicEndpointFn(r.WithContext(ctx))) {
+		opts = append(opts, trace.WithNewRoot())
+		// Linking incoming span context if any for public endpoint.
+		if s := trace.SpanContextFromContext(ctx); s.IsValid() && s.IsRemote() {
+			opts = append(opts, trace.WithLinks(trace.Link{SpanContext: s}))
+		}
+	}
+
+	tracer := h.tracer
+
+	if tracer == nil {
+		if span := trace.SpanFromContext(r.Context()); span.SpanContext().IsValid() {
+			tracer = newTracer(span.TracerProvider())
+		} else {
+			tracer = newTracer(otel.GetTracerProvider())
+		}
+	}
+
+	ctx, span := tracer.Start(ctx, h.spanNameFormatter(h.operation, r), opts...)
+	defer span.End()
+
+	readRecordFunc := func(int64) {}
+	if h.readEvent {
+		readRecordFunc = func(n int64) {
+			span.AddEvent("read", trace.WithAttributes(ReadBytesKey.Int64(n)))
+		}
+	}
+
+	var bw bodyWrapper
+	// If the request body is nil or http.NoBody, don't wrap it: replacing the body would
+	// change its identity in ways other code may rely on, since nil and http.NoBody are
+	// checked for by identity.
+ if r.Body != nil && r.Body != http.NoBody { + bw.ReadCloser = r.Body + bw.record = readRecordFunc + r.Body = &bw + } + + writeRecordFunc := func(int64) {} + if h.writeEvent { + writeRecordFunc = func(n int64) { + span.AddEvent("write", trace.WithAttributes(WroteBytesKey.Int64(n))) + } + } + + rww := &respWriterWrapper{ + ResponseWriter: w, + record: writeRecordFunc, + ctx: ctx, + props: h.propagators, + statusCode: http.StatusOK, // default status code in case the Handler doesn't write anything + } + + // Wrap w to use our ResponseWriter methods while also exposing + // other interfaces that w may implement (http.CloseNotifier, + // http.Flusher, http.Hijacker, http.Pusher, io.ReaderFrom). + + w = httpsnoop.Wrap(w, httpsnoop.Hooks{ + Header: func(httpsnoop.HeaderFunc) httpsnoop.HeaderFunc { + return rww.Header + }, + Write: func(httpsnoop.WriteFunc) httpsnoop.WriteFunc { + return rww.Write + }, + WriteHeader: func(httpsnoop.WriteHeaderFunc) httpsnoop.WriteHeaderFunc { + return rww.WriteHeader + }, + }) + + labeler := &Labeler{} + ctx = injectLabeler(ctx, labeler) + + h.handler.ServeHTTP(w, r.WithContext(ctx)) + + setAfterServeAttributes(span, bw.read, rww.written, rww.statusCode, bw.err, rww.err) + + // Add metrics + attributes := append(labeler.Get(), httpconv.ServerRequest(h.server, r)...) + if rww.statusCode > 0 { + attributes = append(attributes, semconv.HTTPStatusCode(rww.statusCode)) + } + o := metric.WithAttributes(attributes...) + h.counters[RequestContentLength].Add(ctx, bw.read, o) + h.counters[ResponseContentLength].Add(ctx, rww.written, o) + + // Use floating point division here for higher precision (instead of Millisecond method). + elapsedTime := float64(time.Since(requestStartTime)) / float64(time.Millisecond) + + h.valueRecorders[ServerLatency].Record(ctx, elapsedTime, o) +} + +func setAfterServeAttributes(span trace.Span, read, wrote int64, statusCode int, rerr, werr error) { + attributes := []attribute.KeyValue{} + + // TODO: Consider adding an event after each read and write, possibly as an + // option (defaulting to off), so as to not create needlessly verbose spans. + if read > 0 { + attributes = append(attributes, ReadBytesKey.Int64(read)) + } + if rerr != nil && rerr != io.EOF { + attributes = append(attributes, ReadErrorKey.String(rerr.Error())) + } + if wrote > 0 { + attributes = append(attributes, WroteBytesKey.Int64(wrote)) + } + if statusCode > 0 { + attributes = append(attributes, semconv.HTTPStatusCode(statusCode)) + } + span.SetStatus(httpconv.ServerStatus(statusCode)) + + if werr != nil && werr != io.EOF { + attributes = append(attributes, WriteErrorKey.String(werr.Error())) + } + span.SetAttributes(attributes...) +} + +// WithRouteTag annotates a span with the provided route name using the +// RouteKey Tag. 
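+// The route attribute is set on the span already started by the surrounding Handler; this
+// wrapper does not start a span of its own.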
+func WithRouteTag(route string, h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + span := trace.SpanFromContext(r.Context()) + span.SetAttributes(semconv.HTTPRoute(route)) + h.ServeHTTP(w, r) + }) +} diff --git a/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/transport.go b/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/transport.go new file mode 100644 index 0000000000..9dda7e1a95 --- /dev/null +++ b/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/transport.go @@ -0,0 +1,193 @@ +// Copyright The OpenTelemetry Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package otelhttp // import "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" + +import ( + "context" + "io" + "net/http" + "net/http/httptrace" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/propagation" + "go.opentelemetry.io/otel/semconv/v1.17.0/httpconv" + "go.opentelemetry.io/otel/trace" +) + +// Transport implements the http.RoundTripper interface and wraps +// outbound HTTP(S) requests with a span. +type Transport struct { + rt http.RoundTripper + + tracer trace.Tracer + propagators propagation.TextMapPropagator + spanStartOptions []trace.SpanStartOption + filters []Filter + spanNameFormatter func(string, *http.Request) string + clientTrace func(context.Context) *httptrace.ClientTrace +} + +var _ http.RoundTripper = &Transport{} + +// NewTransport wraps the provided http.RoundTripper with one that +// starts a span and injects the span context into the outbound request headers. +// +// If the provided http.RoundTripper is nil, http.DefaultTransport will be used +// as the base http.RoundTripper. +func NewTransport(base http.RoundTripper, opts ...Option) *Transport { + if base == nil { + base = http.DefaultTransport + } + + t := Transport{ + rt: base, + } + + defaultOpts := []Option{ + WithSpanOptions(trace.WithSpanKind(trace.SpanKindClient)), + WithSpanNameFormatter(defaultTransportFormatter), + } + + c := newConfig(append(defaultOpts, opts...)...) + t.applyConfig(c) + + return &t +} + +func (t *Transport) applyConfig(c *config) { + t.tracer = c.Tracer + t.propagators = c.Propagators + t.spanStartOptions = c.SpanStartOptions + t.filters = c.Filters + t.spanNameFormatter = c.SpanNameFormatter + t.clientTrace = c.ClientTrace +} + +func defaultTransportFormatter(_ string, r *http.Request) string { + return "HTTP " + r.Method +} + +// RoundTrip creates a Span and propagates its context via the provided request's headers +// before handing the request to the configured base RoundTripper. The created span will +// end when the response body is closed or when a read from the body returns io.EOF. 
+func (t *Transport) RoundTrip(r *http.Request) (*http.Response, error) {
+	for _, f := range t.filters {
+		if !f(r) {
+			// Simply pass through to the base RoundTripper if a filter rejects the request
+			return t.rt.RoundTrip(r)
+		}
+	}
+
+	tracer := t.tracer
+
+	if tracer == nil {
+		if span := trace.SpanFromContext(r.Context()); span.SpanContext().IsValid() {
+			tracer = newTracer(span.TracerProvider())
+		} else {
+			tracer = newTracer(otel.GetTracerProvider())
+		}
+	}
+
+	opts := append([]trace.SpanStartOption{}, t.spanStartOptions...) // start with the configured options
+
+	ctx, span := tracer.Start(r.Context(), t.spanNameFormatter("", r), opts...)
+
+	if t.clientTrace != nil {
+		ctx = httptrace.WithClientTrace(ctx, t.clientTrace(ctx))
+	}
+
+	r = r.WithContext(ctx)
+	span.SetAttributes(httpconv.ClientRequest(r)...)
+	t.propagators.Inject(ctx, propagation.HeaderCarrier(r.Header))
+
+	res, err := t.rt.RoundTrip(r)
+	if err != nil {
+		span.RecordError(err)
+		span.SetStatus(codes.Error, err.Error())
+		span.End()
+		return res, err
+	}
+
+	span.SetAttributes(httpconv.ClientResponse(res)...)
+	span.SetStatus(httpconv.ClientStatus(res.StatusCode))
+	res.Body = newWrappedBody(span, res.Body)
+
+	return res, err
+}
+
+// newWrappedBody returns a new and appropriately scoped *wrappedBody as an
+// io.ReadCloser. If the passed body implements io.Writer, the returned value
+// will implement io.ReadWriteCloser.
+func newWrappedBody(span trace.Span, body io.ReadCloser) io.ReadCloser {
+	// Successful protocol switch responses will have a body that
+	// implements io.ReadWriteCloser. Ensure this interface type continues
+	// to be satisfied if that is the case.
+	if _, ok := body.(io.ReadWriteCloser); ok {
+		return &wrappedBody{span: span, body: body}
+	}
+
+	// Remove the implementation of the io.ReadWriteCloser and only implement
+	// the io.ReadCloser.
+	return struct{ io.ReadCloser }{&wrappedBody{span: span, body: body}}
+}
+
+// wrappedBody is the response body type returned by the transport
+// instrumentation to complete a span. Errors encountered when using the
+// response body are recorded in the span tracking the response.
+//
+// The span tracking the response is ended when this body is closed.
+//
+// If the response body implements the io.Writer interface (i.e. for
+// successful protocol switches), the wrapped body also will.
+type wrappedBody struct {
+	span trace.Span
+	body io.ReadCloser
+}
+
+var _ io.ReadWriteCloser = &wrappedBody{}
+
+func (wb *wrappedBody) Write(p []byte) (int, error) {
+	// This will not panic given the guard in newWrappedBody.
+ n, err := wb.body.(io.Writer).Write(p) + if err != nil { + wb.span.RecordError(err) + wb.span.SetStatus(codes.Error, err.Error()) + } + return n, err +} + +func (wb *wrappedBody) Read(b []byte) (int, error) { + n, err := wb.body.Read(b) + + switch err { + case nil: + // nothing to do here but fall through to the return + case io.EOF: + wb.span.End() + default: + wb.span.RecordError(err) + wb.span.SetStatus(codes.Error, err.Error()) + } + return n, err +} + +func (wb *wrappedBody) Close() error { + wb.span.End() + if wb.body != nil { + return wb.body.Close() + } + return nil +} diff --git a/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/version.go b/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/version.go new file mode 100644 index 0000000000..bbcbb74160 --- /dev/null +++ b/vendor/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp/version.go @@ -0,0 +1,28 @@ +// Copyright The OpenTelemetry Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package otelhttp // import "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" + +// Version is the current release version of the otelhttp instrumentation. +func Version() string { + return "0.42.0" + // This string is updated by the pre_release.sh script during release +} + +// SemVersion is the semantic version to be supplied to tracer/meter creation. +// +// Deprecated: Use [Version] instead. +func SemVersion() string { + return Version() +} diff --git a/vendor/go.opentelemetry.io/otel/exporters/jaeger/README.md b/vendor/go.opentelemetry.io/otel/exporters/jaeger/README.md new file mode 100644 index 0000000000..19060ba4fd --- /dev/null +++ b/vendor/go.opentelemetry.io/otel/exporters/jaeger/README.md @@ -0,0 +1,50 @@ +# OpenTelemetry-Go Jaeger Exporter + +[![Go Reference](https://pkg.go.dev/badge/go.opentelemetry.io/otel/exporters/jaeger.svg)](https://pkg.go.dev/go.opentelemetry.io/otel/exporters/jaeger) + +[OpenTelemetry span exporter for Jaeger](https://github.com/open-telemetry/opentelemetry-specification/blob/v1.20.0/specification/trace/sdk_exporters/jaeger.md) implementation. + +## Installation + +``` +go get -u go.opentelemetry.io/otel/exporters/jaeger +``` + +## Example + +See [../../example/jaeger](../../example/jaeger). + +## Configuration + +The exporter can be used to send spans to: + +- Jaeger agent using `jaeger.thrift` over compact thrift protocol via + [`WithAgentEndpoint`](https://pkg.go.dev/go.opentelemetry.io/otel/exporters/jaeger#WithAgentEndpoint) option. +- Jaeger collector using `jaeger.thrift` over HTTP via + [`WithCollectorEndpoint`](https://pkg.go.dev/go.opentelemetry.io/otel/exporters/jaeger#WithCollectorEndpoint) option. + +### Environment Variables + +The following environment variables can be used +(instead of options objects) to override the default configuration. 
+
+| Environment variable              | Option                                                                                        | Default value                       |
+| --------------------------------- | --------------------------------------------------------------------------------------------- | ----------------------------------- |
+| `OTEL_EXPORTER_JAEGER_AGENT_HOST` | [`WithAgentHost`](https://pkg.go.dev/go.opentelemetry.io/otel/exporters/jaeger#WithAgentHost) | `localhost`                         |
+| `OTEL_EXPORTER_JAEGER_AGENT_PORT` | [`WithAgentPort`](https://pkg.go.dev/go.opentelemetry.io/otel/exporters/jaeger#WithAgentPort) | `6831`                              |
+| `OTEL_EXPORTER_JAEGER_ENDPOINT`   | [`WithEndpoint`](https://pkg.go.dev/go.opentelemetry.io/otel/exporters/jaeger#WithEndpoint)   | `http://localhost:14268/api/traces` |
+| `OTEL_EXPORTER_JAEGER_USER`       | [`WithUsername`](https://pkg.go.dev/go.opentelemetry.io/otel/exporters/jaeger#WithUsername)   |                                     |
+| `OTEL_EXPORTER_JAEGER_PASSWORD`   | [`WithPassword`](https://pkg.go.dev/go.opentelemetry.io/otel/exporters/jaeger#WithPassword)   |                                     |
+
+Configuration using options takes precedence over the environment variables.
+
+## Contributing
+
+This exporter uses a vendored copy of the Apache Thrift library (v0.14.1) at a custom import path.
+When re-generating Thrift code in the future, please adapt import paths as necessary.
+
+## References
+
+- [Jaeger](https://www.jaegertracing.io/)
+- [OpenTelemetry to Jaeger Transformation](https://github.com/open-telemetry/opentelemetry-specification/blob/v1.20.0/specification/trace/sdk_exporters/jaeger.md)
+- [OpenTelemetry Environment Variable Specification](https://github.com/open-telemetry/opentelemetry-specification/blob/v1.20.0/specification/sdk-environment-variables.md#jaeger-exporter)
diff --git a/vendor/go.opentelemetry.io/otel/exporters/jaeger/doc.go b/vendor/go.opentelemetry.io/otel/exporters/jaeger/doc.go
new file mode 100644
index 0000000000..0d7ba86764
--- /dev/null
+++ b/vendor/go.opentelemetry.io/otel/exporters/jaeger/doc.go
@@ -0,0 +1,16 @@
+// Copyright The OpenTelemetry Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package jaeger contains an OpenTelemetry tracing exporter for Jaeger.
+package jaeger // import "go.opentelemetry.io/otel/exporters/jaeger"
diff --git a/vendor/go.opentelemetry.io/otel/exporters/jaeger/jaeger.go b/vendor/go.opentelemetry.io/otel/exporters/jaeger/jaeger.go
new file mode 100644
index 0000000000..ddbd681d00
--- /dev/null
+++ b/vendor/go.opentelemetry.io/otel/exporters/jaeger/jaeger.go
@@ -0,0 +1,360 @@
+// Copyright The OpenTelemetry Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. + +package jaeger // import "go.opentelemetry.io/otel/exporters/jaeger" + +import ( + "context" + "encoding/binary" + "encoding/json" + "fmt" + "sync" + + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + gen "go.opentelemetry.io/otel/exporters/jaeger/internal/gen-go/jaeger" + "go.opentelemetry.io/otel/sdk/resource" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + semconv "go.opentelemetry.io/otel/semconv/v1.17.0" + "go.opentelemetry.io/otel/trace" +) + +const ( + keyInstrumentationLibraryName = "otel.library.name" + keyInstrumentationLibraryVersion = "otel.library.version" + keyError = "error" + keySpanKind = "span.kind" + keyStatusCode = "otel.status_code" + keyStatusMessage = "otel.status_description" + keyDroppedAttributeCount = "otel.event.dropped_attributes_count" + keyEventName = "event" +) + +// New returns an OTel Exporter implementation that exports the collected +// spans to Jaeger. +func New(endpointOption EndpointOption) (*Exporter, error) { + uploader, err := endpointOption.newBatchUploader() + if err != nil { + return nil, err + } + + // Fetch default service.name from default resource for backup + var defaultServiceName string + defaultResource := resource.Default() + if value, exists := defaultResource.Set().Value(semconv.ServiceNameKey); exists { + defaultServiceName = value.AsString() + } + if defaultServiceName == "" { + return nil, fmt.Errorf("failed to get service name from default resource") + } + + stopCh := make(chan struct{}) + e := &Exporter{ + uploader: uploader, + stopCh: stopCh, + defaultServiceName: defaultServiceName, + } + return e, nil +} + +// Exporter exports OpenTelemetry spans to a Jaeger agent or collector. +type Exporter struct { + uploader batchUploader + stopOnce sync.Once + stopCh chan struct{} + defaultServiceName string +} + +var _ sdktrace.SpanExporter = (*Exporter)(nil) + +// ExportSpans transforms and exports OpenTelemetry spans to Jaeger. +func (e *Exporter) ExportSpans(ctx context.Context, spans []sdktrace.ReadOnlySpan) error { + // Return fast if context is already canceled or Exporter shutdown. + select { + case <-ctx.Done(): + return ctx.Err() + case <-e.stopCh: + return nil + default: + } + + // Cancel export if Exporter is shutdown. + var cancel context.CancelFunc + ctx, cancel = context.WithCancel(ctx) + defer cancel() + go func(ctx context.Context, cancel context.CancelFunc) { + select { + case <-ctx.Done(): + case <-e.stopCh: + cancel() + } + }(ctx, cancel) + + for _, batch := range jaegerBatchList(spans, e.defaultServiceName) { + if err := e.uploader.upload(ctx, batch); err != nil { + return err + } + } + + return nil +} + +// Shutdown stops the Exporter. This will close all connections and release +// all resources held by the Exporter. +func (e *Exporter) Shutdown(ctx context.Context) error { + // Stop any active and subsequent exports. + e.stopOnce.Do(func() { close(e.stopCh) }) + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + return e.uploader.shutdown(ctx) +} + +// MarshalLog is the marshaling function used by the logging system to represent this exporter. 
+func (e *Exporter) MarshalLog() interface{} { + return struct { + Type string + }{ + Type: "jaeger", + } +} + +func spanToThrift(ss sdktrace.ReadOnlySpan) *gen.Span { + attr := ss.Attributes() + tags := make([]*gen.Tag, 0, len(attr)) + for _, kv := range attr { + tag := keyValueToTag(kv) + if tag != nil { + tags = append(tags, tag) + } + } + + if is := ss.InstrumentationScope(); is.Name != "" { + tags = append(tags, getStringTag(keyInstrumentationLibraryName, is.Name)) + if is.Version != "" { + tags = append(tags, getStringTag(keyInstrumentationLibraryVersion, is.Version)) + } + } + + if ss.SpanKind() != trace.SpanKindInternal { + tags = append(tags, + getStringTag(keySpanKind, ss.SpanKind().String()), + ) + } + + if ss.Status().Code != codes.Unset { + switch ss.Status().Code { + case codes.Ok: + tags = append(tags, getStringTag(keyStatusCode, "OK")) + case codes.Error: + tags = append(tags, getBoolTag(keyError, true)) + tags = append(tags, getStringTag(keyStatusCode, "ERROR")) + } + if ss.Status().Description != "" { + tags = append(tags, getStringTag(keyStatusMessage, ss.Status().Description)) + } + } + + var logs []*gen.Log + for _, a := range ss.Events() { + nTags := len(a.Attributes) + if a.Name != "" { + nTags++ + } + if a.DroppedAttributeCount != 0 { + nTags++ + } + fields := make([]*gen.Tag, 0, nTags) + if a.Name != "" { + // If an event contains an attribute with the same key, it needs + // to be given precedence and overwrite this. + fields = append(fields, getStringTag(keyEventName, a.Name)) + } + for _, kv := range a.Attributes { + tag := keyValueToTag(kv) + if tag != nil { + fields = append(fields, tag) + } + } + if a.DroppedAttributeCount != 0 { + fields = append(fields, getInt64Tag(keyDroppedAttributeCount, int64(a.DroppedAttributeCount))) + } + logs = append(logs, &gen.Log{ + Timestamp: a.Time.UnixNano() / 1000, + Fields: fields, + }) + } + + var refs []*gen.SpanRef + for _, link := range ss.Links() { + tid := link.SpanContext.TraceID() + sid := link.SpanContext.SpanID() + refs = append(refs, &gen.SpanRef{ + TraceIdHigh: int64(binary.BigEndian.Uint64(tid[0:8])), + TraceIdLow: int64(binary.BigEndian.Uint64(tid[8:16])), + SpanId: int64(binary.BigEndian.Uint64(sid[:])), + RefType: gen.SpanRefType_FOLLOWS_FROM, + }) + } + + tid := ss.SpanContext().TraceID() + sid := ss.SpanContext().SpanID() + psid := ss.Parent().SpanID() + return &gen.Span{ + TraceIdHigh: int64(binary.BigEndian.Uint64(tid[0:8])), + TraceIdLow: int64(binary.BigEndian.Uint64(tid[8:16])), + SpanId: int64(binary.BigEndian.Uint64(sid[:])), + ParentSpanId: int64(binary.BigEndian.Uint64(psid[:])), + OperationName: ss.Name(), // TODO: if span kind is added then add prefix "Sent"/"Recv" + Flags: int32(ss.SpanContext().TraceFlags()), + StartTime: ss.StartTime().UnixNano() / 1000, + Duration: ss.EndTime().Sub(ss.StartTime()).Nanoseconds() / 1000, + Tags: tags, + Logs: logs, + References: refs, + } +} + +func keyValueToTag(keyValue attribute.KeyValue) *gen.Tag { + var tag *gen.Tag + switch keyValue.Value.Type() { + case attribute.STRING: + s := keyValue.Value.AsString() + tag = &gen.Tag{ + Key: string(keyValue.Key), + VStr: &s, + VType: gen.TagType_STRING, + } + case attribute.BOOL: + b := keyValue.Value.AsBool() + tag = &gen.Tag{ + Key: string(keyValue.Key), + VBool: &b, + VType: gen.TagType_BOOL, + } + case attribute.INT64: + i := keyValue.Value.AsInt64() + tag = &gen.Tag{ + Key: string(keyValue.Key), + VLong: &i, + VType: gen.TagType_LONG, + } + case attribute.FLOAT64: + f := keyValue.Value.AsFloat64() + tag = 
&gen.Tag{ + Key: string(keyValue.Key), + VDouble: &f, + VType: gen.TagType_DOUBLE, + } + case attribute.BOOLSLICE, + attribute.INT64SLICE, + attribute.FLOAT64SLICE, + attribute.STRINGSLICE: + data, _ := json.Marshal(keyValue.Value.AsInterface()) + a := (string)(data) + tag = &gen.Tag{ + Key: string(keyValue.Key), + VStr: &a, + VType: gen.TagType_STRING, + } + } + return tag +} + +func getInt64Tag(k string, i int64) *gen.Tag { + return &gen.Tag{ + Key: k, + VLong: &i, + VType: gen.TagType_LONG, + } +} + +func getStringTag(k, s string) *gen.Tag { + return &gen.Tag{ + Key: k, + VStr: &s, + VType: gen.TagType_STRING, + } +} + +func getBoolTag(k string, b bool) *gen.Tag { + return &gen.Tag{ + Key: k, + VBool: &b, + VType: gen.TagType_BOOL, + } +} + +// jaegerBatchList transforms a slice of spans into a slice of jaeger Batch. +func jaegerBatchList(ssl []sdktrace.ReadOnlySpan, defaultServiceName string) []*gen.Batch { + if len(ssl) == 0 { + return nil + } + + batchDict := make(map[attribute.Distinct]*gen.Batch) + + for _, ss := range ssl { + if ss == nil { + continue + } + + resourceKey := ss.Resource().Equivalent() + batch, bOK := batchDict[resourceKey] + if !bOK { + batch = &gen.Batch{ + Process: process(ss.Resource(), defaultServiceName), + Spans: []*gen.Span{}, + } + } + batch.Spans = append(batch.Spans, spanToThrift(ss)) + batchDict[resourceKey] = batch + } + + // Transform the categorized map into a slice + batchList := make([]*gen.Batch, 0, len(batchDict)) + for _, batch := range batchDict { + batchList = append(batchList, batch) + } + return batchList +} + +// process transforms an OTel Resource into a jaeger Process. +func process(res *resource.Resource, defaultServiceName string) *gen.Process { + var process gen.Process + + var serviceName attribute.KeyValue + if res != nil { + for iter := res.Iter(); iter.Next(); { + if iter.Attribute().Key == semconv.ServiceNameKey { + serviceName = iter.Attribute() + // Don't convert service.name into tag. + continue + } + if tag := keyValueToTag(iter.Attribute()); tag != nil { + process.Tags = append(process.Tags, tag) + } + } + } + + // If no service.name is contained in a Span's Resource, + // that field MUST be populated from the default Resource. + if serviceName.Value.AsString() == "" { + serviceName = semconv.ServiceName(defaultServiceName) + } + process.ServiceName = serviceName.Value.AsString() + + return &process +} diff --git a/vendor/go.opentelemetry.io/otel/sdk/resource/builtin.go b/vendor/go.opentelemetry.io/otel/sdk/resource/builtin.go new file mode 100644 index 0000000000..72320ca51f --- /dev/null +++ b/vendor/go.opentelemetry.io/otel/sdk/resource/builtin.go @@ -0,0 +1,108 @@ +// Copyright The OpenTelemetry Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
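`jaegerBatchList` above groups spans into one `gen.Batch` per distinct `Resource`,
keyed by `Resource.Equivalent()`. A small sketch of that keying property using the
public resource API (attribute values are illustrative):

```go
package main

import (
	"fmt"

	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/sdk/resource"
)

func main() {
	// Two resources built independently from equal attribute sets yield the
	// same comparable key, so their spans would land in the same batch.
	r1 := resource.NewSchemaless(attribute.String("service.name", "checkout"))
	r2 := resource.NewSchemaless(attribute.String("service.name", "checkout"))
	fmt.Println(r1.Equivalent() == r2.Equivalent()) // true
}
```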
+ +package resource // import "go.opentelemetry.io/otel/sdk/resource" + +import ( + "context" + "fmt" + "os" + "path/filepath" + + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/sdk" + semconv "go.opentelemetry.io/otel/semconv/v1.17.0" +) + +type ( + // telemetrySDK is a Detector that provides information about + // the OpenTelemetry SDK used. This Detector is included as a + // builtin. If these resource attributes are not wanted, use + // the WithTelemetrySDK(nil) or WithoutBuiltin() options to + // explicitly disable them. + telemetrySDK struct{} + + // host is a Detector that provides information about the host + // being run on. This Detector is included as a builtin. If + // these resource attributes are not wanted, use the + // WithHost(nil) or WithoutBuiltin() options to explicitly + // disable them. + host struct{} + + stringDetector struct { + schemaURL string + K attribute.Key + F func() (string, error) + } + + defaultServiceNameDetector struct{} +) + +var ( + _ Detector = telemetrySDK{} + _ Detector = host{} + _ Detector = stringDetector{} + _ Detector = defaultServiceNameDetector{} +) + +// Detect returns a *Resource that describes the OpenTelemetry SDK used. +func (telemetrySDK) Detect(context.Context) (*Resource, error) { + return NewWithAttributes( + semconv.SchemaURL, + semconv.TelemetrySDKName("opentelemetry"), + semconv.TelemetrySDKLanguageGo, + semconv.TelemetrySDKVersion(sdk.Version()), + ), nil +} + +// Detect returns a *Resource that describes the host being run on. +func (host) Detect(ctx context.Context) (*Resource, error) { + return StringDetector(semconv.SchemaURL, semconv.HostNameKey, os.Hostname).Detect(ctx) +} + +// StringDetector returns a Detector that will produce a *Resource +// containing the string as a value corresponding to k. The resulting Resource +// will have the specified schemaURL. +func StringDetector(schemaURL string, k attribute.Key, f func() (string, error)) Detector { + return stringDetector{schemaURL: schemaURL, K: k, F: f} +} + +// Detect returns a *Resource that describes the string as a value +// corresponding to attribute.Key as well as the specific schemaURL. +func (sd stringDetector) Detect(ctx context.Context) (*Resource, error) { + value, err := sd.F() + if err != nil { + return nil, fmt.Errorf("%s: %w", string(sd.K), err) + } + a := sd.K.String(value) + if !a.Valid() { + return nil, fmt.Errorf("invalid attribute: %q -> %q", a.Key, a.Value.Emit()) + } + return NewWithAttributes(sd.schemaURL, sd.K.String(value)), nil +} + +// Detect implements Detector. +func (defaultServiceNameDetector) Detect(ctx context.Context) (*Resource, error) { + return StringDetector( + semconv.SchemaURL, + semconv.ServiceNameKey, + func() (string, error) { + executable, err := os.Executable() + if err != nil { + return "unknown_service:go", nil + } + return "unknown_service:" + filepath.Base(executable), nil + }, + ).Detect(ctx) +} diff --git a/vendor/go.opentelemetry.io/otel/sdk/resource/container.go b/vendor/go.opentelemetry.io/otel/sdk/resource/container.go new file mode 100644 index 0000000000..318dcf82fe --- /dev/null +++ b/vendor/go.opentelemetry.io/otel/sdk/resource/container.go @@ -0,0 +1,100 @@ +// Copyright The OpenTelemetry Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package resource // import "go.opentelemetry.io/otel/sdk/resource"
+
+import (
+	"bufio"
+	"context"
+	"errors"
+	"io"
+	"os"
+	"regexp"
+
+	semconv "go.opentelemetry.io/otel/semconv/v1.17.0"
+)
+
+type containerIDProvider func() (string, error)
+
+var (
+	containerID containerIDProvider = getContainerIDFromCGroup
+	cgroupContainerIDRe             = regexp.MustCompile(`^.*/(?:.*-)?([0-9a-f]+)(?:\.|\s*$)`)
+)
+
+type cgroupContainerIDDetector struct{}
+
+const cgroupPath = "/proc/self/cgroup"
+
+// Detect returns a *Resource that describes the id of the container.
+// If no container id is found, an empty resource will be returned.
+func (cgroupContainerIDDetector) Detect(ctx context.Context) (*Resource, error) {
+	containerID, err := containerID()
+	if err != nil {
+		return nil, err
+	}
+
+	if containerID == "" {
+		return Empty(), nil
+	}
+	return NewWithAttributes(semconv.SchemaURL, semconv.ContainerID(containerID)), nil
+}
+
+var (
+	defaultOSStat = os.Stat
+	osStat        = defaultOSStat
+
+	defaultOSOpen = func(name string) (io.ReadCloser, error) {
+		return os.Open(name)
+	}
+	osOpen = defaultOSOpen
+)
+
+// getContainerIDFromCGroup returns the id of the container from the cgroup file.
+// If no container id is found, an empty string will be returned.
+func getContainerIDFromCGroup() (string, error) {
+	if _, err := osStat(cgroupPath); errors.Is(err, os.ErrNotExist) {
+		// File does not exist, skip
+		return "", nil
+	}
+
+	file, err := osOpen(cgroupPath)
+	if err != nil {
+		return "", err
+	}
+	defer file.Close()
+
+	return getContainerIDFromReader(file), nil
+}
+
+// getContainerIDFromReader returns the id of the container from reader.
+func getContainerIDFromReader(reader io.Reader) string {
+	scanner := bufio.NewScanner(reader)
+	for scanner.Scan() {
+		line := scanner.Text()
+
+		if id := getContainerIDFromLine(line); id != "" {
+			return id
+		}
+	}
+	return ""
+}
+
+// getContainerIDFromLine returns the id of the container from one string line.
+func getContainerIDFromLine(line string) string {
+	matches := cgroupContainerIDRe.FindStringSubmatch(line)
+	if len(matches) <= 1 {
+		return ""
+	}
+	return matches[1]
+}
diff --git a/vendor/go.opentelemetry.io/otel/sdk/resource/env.go b/vendor/go.opentelemetry.io/otel/sdk/resource/env.go
new file mode 100644
index 0000000000..f09a781906
--- /dev/null
+++ b/vendor/go.opentelemetry.io/otel/sdk/resource/env.go
@@ -0,0 +1,108 @@
+// Copyright The OpenTelemetry Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
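The cgroup parsing above extracts the trailing hex token from each line of
`/proc/self/cgroup`. A standalone sketch of the same regular expression (the
sample line and id are fabricated):

```go
package main

import (
	"fmt"
	"regexp"
)

// Copied from the detector above for illustration.
var cgroupLineRe = regexp.MustCompile(`^.*/(?:.*-)?([0-9a-f]+)(?:\.|\s*$)`)

func main() {
	line := "13:name=systemd:/pod/crio-abcdef0123456789.scope"
	if m := cgroupLineRe.FindStringSubmatch(line); len(m) > 1 {
		fmt.Println(m[1]) // abcdef0123456789
	}
}
```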
+
+package resource // import "go.opentelemetry.io/otel/sdk/resource"
+
+import (
+	"context"
+	"fmt"
+	"net/url"
+	"os"
+	"strings"
+
+	"go.opentelemetry.io/otel"
+	"go.opentelemetry.io/otel/attribute"
+	semconv "go.opentelemetry.io/otel/semconv/v1.17.0"
+)
+
+const (
+	// resourceAttrKey is the environment variable name OpenTelemetry Resource information will be read from.
+	resourceAttrKey = "OTEL_RESOURCE_ATTRIBUTES"
+
+	// svcNameKey is the environment variable name that Service Name information will be read from.
+	svcNameKey = "OTEL_SERVICE_NAME"
+)
+
+var (
+	// errMissingValue is returned when a resource value is missing.
+	errMissingValue = fmt.Errorf("%w: missing value", ErrPartialResource)
+)
+
+// fromEnv is a Detector that collects resources from the environment. This
+// Detector is included as a builtin.
+type fromEnv struct{}
+
+// compile time assertion that fromEnv implements the Detector interface.
+var _ Detector = fromEnv{}
+
+// Detect collects resources from environment.
+func (fromEnv) Detect(context.Context) (*Resource, error) {
+	attrs := strings.TrimSpace(os.Getenv(resourceAttrKey))
+	svcName := strings.TrimSpace(os.Getenv(svcNameKey))
+
+	if attrs == "" && svcName == "" {
+		return Empty(), nil
+	}
+
+	var res *Resource
+
+	if svcName != "" {
+		res = NewSchemaless(semconv.ServiceName(svcName))
+	}
+
+	r2, err := constructOTResources(attrs)
+
+	// Ensure that the resource with the service name from OTEL_SERVICE_NAME
+	// takes precedence, if it was defined.
+	res, err2 := Merge(r2, res)
+
+	if err == nil {
+		err = err2
+	} else if err2 != nil {
+		err = fmt.Errorf("detecting resources: %s", []string{err.Error(), err2.Error()})
+	}
+
+	return res, err
+}
+
+func constructOTResources(s string) (*Resource, error) {
+	if s == "" {
+		return Empty(), nil
+	}
+	pairs := strings.Split(s, ",")
+	var attrs []attribute.KeyValue
+	var invalid []string
+	for _, p := range pairs {
+		k, v, found := strings.Cut(p, "=")
+		if !found {
+			invalid = append(invalid, p)
+			continue
+		}
+		key := strings.TrimSpace(k)
+		val, err := url.QueryUnescape(strings.TrimSpace(v))
+		if err != nil {
+			// Retain original value if decoding fails, otherwise it will be
+			// an empty string.
+			val = v
+			otel.Handle(err)
+		}
+		attrs = append(attrs, attribute.String(key, val))
+	}
+	var err error
+	if len(invalid) > 0 {
+		err = fmt.Errorf("%w: %v", errMissingValue, invalid)
+	}
+	return NewSchemaless(attrs...), err
+}
diff --git a/vendor/go.opentelemetry.io/otel/sdk/resource/host_id.go b/vendor/go.opentelemetry.io/otel/sdk/resource/host_id.go
new file mode 100644
index 0000000000..b8e934d4f8
--- /dev/null
+++ b/vendor/go.opentelemetry.io/otel/sdk/resource/host_id.go
@@ -0,0 +1,120 @@
+// Copyright The OpenTelemetry Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
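The `fromEnv` detector above is what makes `OTEL_SERVICE_NAME` win over a
`service.name` inside `OTEL_RESOURCE_ATTRIBUTES`. A minimal sketch using the
package's `Environment` helper (variable values are illustrative):

```go
package main

import (
	"fmt"
	"os"

	"go.opentelemetry.io/otel/sdk/resource"
)

func main() {
	os.Setenv("OTEL_SERVICE_NAME", "checkout")
	os.Setenv("OTEL_RESOURCE_ATTRIBUTES", "deployment.environment=prod,service.name=ignored")

	// service.name from OTEL_SERVICE_NAME takes precedence.
	fmt.Println(resource.Environment())
	// deployment.environment=prod,service.name=checkout
}
```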
+ +package resource // import "go.opentelemetry.io/otel/sdk/resource" + +import ( + "context" + "errors" + "strings" + + semconv "go.opentelemetry.io/otel/semconv/v1.17.0" +) + +type hostIDProvider func() (string, error) + +var defaultHostIDProvider hostIDProvider = platformHostIDReader.read + +var hostID = defaultHostIDProvider + +type hostIDReader interface { + read() (string, error) +} + +type fileReader func(string) (string, error) + +type commandExecutor func(string, ...string) (string, error) + +// hostIDReaderBSD implements hostIDReader. +type hostIDReaderBSD struct { + execCommand commandExecutor + readFile fileReader +} + +// read attempts to read the machine-id from /etc/hostid. If not found it will +// execute `kenv -q smbios.system.uuid`. If neither location yields an id an +// error will be returned. +func (r *hostIDReaderBSD) read() (string, error) { + if result, err := r.readFile("/etc/hostid"); err == nil { + return strings.TrimSpace(result), nil + } + + if result, err := r.execCommand("kenv", "-q", "smbios.system.uuid"); err == nil { + return strings.TrimSpace(result), nil + } + + return "", errors.New("host id not found in: /etc/hostid or kenv") +} + +// hostIDReaderDarwin implements hostIDReader. +type hostIDReaderDarwin struct { + execCommand commandExecutor +} + +// read executes `ioreg -rd1 -c "IOPlatformExpertDevice"` and parses host id +// from the IOPlatformUUID line. If the command fails or the uuid cannot be +// parsed an error will be returned. +func (r *hostIDReaderDarwin) read() (string, error) { + result, err := r.execCommand("ioreg", "-rd1", "-c", "IOPlatformExpertDevice") + if err != nil { + return "", err + } + + lines := strings.Split(result, "\n") + for _, line := range lines { + if strings.Contains(line, "IOPlatformUUID") { + parts := strings.Split(line, " = ") + if len(parts) == 2 { + return strings.Trim(parts[1], "\""), nil + } + break + } + } + + return "", errors.New("could not parse IOPlatformUUID") +} + +type hostIDReaderLinux struct { + readFile fileReader +} + +// read attempts to read the machine-id from /etc/machine-id followed by +// /var/lib/dbus/machine-id. If neither location yields an ID an error will +// be returned. +func (r *hostIDReaderLinux) read() (string, error) { + if result, err := r.readFile("/etc/machine-id"); err == nil { + return strings.TrimSpace(result), nil + } + + if result, err := r.readFile("/var/lib/dbus/machine-id"); err == nil { + return strings.TrimSpace(result), nil + } + + return "", errors.New("host id not found in: /etc/machine-id or /var/lib/dbus/machine-id") +} + +type hostIDDetector struct{} + +// Detect returns a *Resource containing the platform specific host id. +func (hostIDDetector) Detect(ctx context.Context) (*Resource, error) { + hostID, err := hostID() + if err != nil { + return nil, err + } + + return NewWithAttributes( + semconv.SchemaURL, + semconv.HostID(hostID), + ), nil +} diff --git a/vendor/go.opentelemetry.io/otel/sdk/resource/host_id_readfile.go b/vendor/go.opentelemetry.io/otel/sdk/resource/host_id_readfile.go new file mode 100644 index 0000000000..f92c6dad0f --- /dev/null +++ b/vendor/go.opentelemetry.io/otel/sdk/resource/host_id_readfile.go @@ -0,0 +1,28 @@ +// Copyright The OpenTelemetry Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build linux || dragonfly || freebsd || netbsd || openbsd || solaris
+
+package resource // import "go.opentelemetry.io/otel/sdk/resource"
+
+import "os"
+
+func readFile(filename string) (string, error) {
+	b, err := os.ReadFile(filename)
+	if err != nil {
+		// Propagate the error so callers can fall back to the next source
+		// instead of treating an empty string as a valid host id.
+		return "", err
+	}
+
+	return string(b), nil
+}
diff --git a/vendor/go.opentelemetry.io/otel/sdk/resource/os.go b/vendor/go.opentelemetry.io/otel/sdk/resource/os.go
new file mode 100644
index 0000000000..815fe5c204
--- /dev/null
+++ b/vendor/go.opentelemetry.io/otel/sdk/resource/os.go
@@ -0,0 +1,97 @@
+// Copyright The OpenTelemetry Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package resource // import "go.opentelemetry.io/otel/sdk/resource"
+
+import (
+	"context"
+	"strings"
+
+	"go.opentelemetry.io/otel/attribute"
+	semconv "go.opentelemetry.io/otel/semconv/v1.17.0"
+)
+
+type osDescriptionProvider func() (string, error)
+
+var defaultOSDescriptionProvider osDescriptionProvider = platformOSDescription
+
+var osDescription = defaultOSDescriptionProvider
+
+func setDefaultOSDescriptionProvider() {
+	setOSDescriptionProvider(defaultOSDescriptionProvider)
+}
+
+func setOSDescriptionProvider(osDescriptionProvider osDescriptionProvider) {
+	osDescription = osDescriptionProvider
+}
+
+type osTypeDetector struct{}
+type osDescriptionDetector struct{}
+
+// Detect returns a *Resource that describes the operating system type the
+// service is running on.
+func (osTypeDetector) Detect(ctx context.Context) (*Resource, error) {
+	osType := runtimeOS()
+
+	osTypeAttribute := mapRuntimeOSToSemconvOSType(osType)
+
+	return NewWithAttributes(
+		semconv.SchemaURL,
+		osTypeAttribute,
+	), nil
+}
+
+// Detect returns a *Resource that describes the operating system the
+// service is running on.
+func (osDescriptionDetector) Detect(ctx context.Context) (*Resource, error) {
+	description, err := osDescription()
+
+	if err != nil {
+		return nil, err
+	}
+
+	return NewWithAttributes(
+		semconv.SchemaURL,
+		semconv.OSDescription(description),
+	), nil
+}
+
+// mapRuntimeOSToSemconvOSType translates the OS name as provided by the Go runtime
+// into an OS type attribute with the corresponding value defined by the semantic
+// conventions. In case the provided OS name isn't mapped, it's transformed to lowercase
+// and used as the value for the returned OS type attribute.
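+// For example, an unmapped GOOS value such as "aix" simply becomes the
+// attribute os.type="aix" (example value; "aix" is not in the map below).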
+func mapRuntimeOSToSemconvOSType(osType string) attribute.KeyValue { + // the elements in this map are the intersection between + // available GOOS values and defined semconv OS types + osTypeAttributeMap := map[string]attribute.KeyValue{ + "darwin": semconv.OSTypeDarwin, + "dragonfly": semconv.OSTypeDragonflyBSD, + "freebsd": semconv.OSTypeFreeBSD, + "linux": semconv.OSTypeLinux, + "netbsd": semconv.OSTypeNetBSD, + "openbsd": semconv.OSTypeOpenBSD, + "solaris": semconv.OSTypeSolaris, + "windows": semconv.OSTypeWindows, + } + + var osTypeAttribute attribute.KeyValue + + if attr, ok := osTypeAttributeMap[osType]; ok { + osTypeAttribute = attr + } else { + osTypeAttribute = semconv.OSTypeKey.String(strings.ToLower(osType)) + } + + return osTypeAttribute +} diff --git a/vendor/go.opentelemetry.io/otel/sdk/resource/process.go b/vendor/go.opentelemetry.io/otel/sdk/resource/process.go new file mode 100644 index 0000000000..bdd0e7fe68 --- /dev/null +++ b/vendor/go.opentelemetry.io/otel/sdk/resource/process.go @@ -0,0 +1,180 @@ +// Copyright The OpenTelemetry Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package resource // import "go.opentelemetry.io/otel/sdk/resource" + +import ( + "context" + "fmt" + "os" + "os/user" + "path/filepath" + "runtime" + + semconv "go.opentelemetry.io/otel/semconv/v1.17.0" +) + +type pidProvider func() int +type executablePathProvider func() (string, error) +type commandArgsProvider func() []string +type ownerProvider func() (*user.User, error) +type runtimeNameProvider func() string +type runtimeVersionProvider func() string +type runtimeOSProvider func() string +type runtimeArchProvider func() string + +var ( + defaultPidProvider pidProvider = os.Getpid + defaultExecutablePathProvider executablePathProvider = os.Executable + defaultCommandArgsProvider commandArgsProvider = func() []string { return os.Args } + defaultOwnerProvider ownerProvider = user.Current + defaultRuntimeNameProvider runtimeNameProvider = func() string { + if runtime.Compiler == "gc" { + return "go" + } + return runtime.Compiler + } + defaultRuntimeVersionProvider runtimeVersionProvider = runtime.Version + defaultRuntimeOSProvider runtimeOSProvider = func() string { return runtime.GOOS } + defaultRuntimeArchProvider runtimeArchProvider = func() string { return runtime.GOARCH } +) + +var ( + pid = defaultPidProvider + executablePath = defaultExecutablePathProvider + commandArgs = defaultCommandArgsProvider + owner = defaultOwnerProvider + runtimeName = defaultRuntimeNameProvider + runtimeVersion = defaultRuntimeVersionProvider + runtimeOS = defaultRuntimeOSProvider + runtimeArch = defaultRuntimeArchProvider +) + +func setDefaultOSProviders() { + setOSProviders( + defaultPidProvider, + defaultExecutablePathProvider, + defaultCommandArgsProvider, + ) +} + +func setOSProviders( + pidProvider pidProvider, + executablePathProvider executablePathProvider, + commandArgsProvider commandArgsProvider, +) { + pid = pidProvider + executablePath = executablePathProvider + commandArgs = 
commandArgsProvider +} + +func setDefaultRuntimeProviders() { + setRuntimeProviders( + defaultRuntimeNameProvider, + defaultRuntimeVersionProvider, + defaultRuntimeOSProvider, + defaultRuntimeArchProvider, + ) +} + +func setRuntimeProviders( + runtimeNameProvider runtimeNameProvider, + runtimeVersionProvider runtimeVersionProvider, + runtimeOSProvider runtimeOSProvider, + runtimeArchProvider runtimeArchProvider, +) { + runtimeName = runtimeNameProvider + runtimeVersion = runtimeVersionProvider + runtimeOS = runtimeOSProvider + runtimeArch = runtimeArchProvider +} + +func setDefaultUserProviders() { + setUserProviders(defaultOwnerProvider) +} + +func setUserProviders(ownerProvider ownerProvider) { + owner = ownerProvider +} + +type processPIDDetector struct{} +type processExecutableNameDetector struct{} +type processExecutablePathDetector struct{} +type processCommandArgsDetector struct{} +type processOwnerDetector struct{} +type processRuntimeNameDetector struct{} +type processRuntimeVersionDetector struct{} +type processRuntimeDescriptionDetector struct{} + +// Detect returns a *Resource that describes the process identifier (PID) of the +// executing process. +func (processPIDDetector) Detect(ctx context.Context) (*Resource, error) { + return NewWithAttributes(semconv.SchemaURL, semconv.ProcessPID(pid())), nil +} + +// Detect returns a *Resource that describes the name of the process executable. +func (processExecutableNameDetector) Detect(ctx context.Context) (*Resource, error) { + executableName := filepath.Base(commandArgs()[0]) + + return NewWithAttributes(semconv.SchemaURL, semconv.ProcessExecutableName(executableName)), nil +} + +// Detect returns a *Resource that describes the full path of the process executable. +func (processExecutablePathDetector) Detect(ctx context.Context) (*Resource, error) { + executablePath, err := executablePath() + if err != nil { + return nil, err + } + + return NewWithAttributes(semconv.SchemaURL, semconv.ProcessExecutablePath(executablePath)), nil +} + +// Detect returns a *Resource that describes all the command arguments as received +// by the process. +func (processCommandArgsDetector) Detect(ctx context.Context) (*Resource, error) { + return NewWithAttributes(semconv.SchemaURL, semconv.ProcessCommandArgs(commandArgs()...)), nil +} + +// Detect returns a *Resource that describes the username of the user that owns the +// process. +func (processOwnerDetector) Detect(ctx context.Context) (*Resource, error) { + owner, err := owner() + if err != nil { + return nil, err + } + + return NewWithAttributes(semconv.SchemaURL, semconv.ProcessOwner(owner.Username)), nil +} + +// Detect returns a *Resource that describes the name of the compiler used to compile +// this process image. +func (processRuntimeNameDetector) Detect(ctx context.Context) (*Resource, error) { + return NewWithAttributes(semconv.SchemaURL, semconv.ProcessRuntimeName(runtimeName())), nil +} + +// Detect returns a *Resource that describes the version of the runtime of this process. +func (processRuntimeVersionDetector) Detect(ctx context.Context) (*Resource, error) { + return NewWithAttributes(semconv.SchemaURL, semconv.ProcessRuntimeVersion(runtimeVersion())), nil +} + +// Detect returns a *Resource that describes the runtime of this process. 
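+// The description has the form "go version go1.20.4 linux/amd64" (the
+// version and platform shown are illustrative).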
+func (processRuntimeDescriptionDetector) Detect(ctx context.Context) (*Resource, error) { + runtimeDescription := fmt.Sprintf( + "go version %s %s/%s", runtimeVersion(), runtimeOS(), runtimeArch()) + + return NewWithAttributes( + semconv.SchemaURL, + semconv.ProcessRuntimeDescription(runtimeDescription), + ), nil +} diff --git a/vendor/go.opentelemetry.io/otel/sdk/resource/resource.go b/vendor/go.opentelemetry.io/otel/sdk/resource/resource.go new file mode 100644 index 0000000000..139dc7e8f9 --- /dev/null +++ b/vendor/go.opentelemetry.io/otel/sdk/resource/resource.go @@ -0,0 +1,272 @@ +// Copyright The OpenTelemetry Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package resource // import "go.opentelemetry.io/otel/sdk/resource" + +import ( + "context" + "errors" + "sync" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" +) + +// Resource describes an entity about which identifying information +// and metadata is exposed. Resource is an immutable object, +// equivalent to a map from key to unique value. +// +// Resources should be passed and stored as pointers +// (`*resource.Resource`). The `nil` value is equivalent to an empty +// Resource. +type Resource struct { + attrs attribute.Set + schemaURL string +} + +var ( + emptyResource Resource + defaultResource *Resource + defaultResourceOnce sync.Once +) + +var errMergeConflictSchemaURL = errors.New("cannot merge resource due to conflicting Schema URL") + +// New returns a Resource combined from the user-provided detectors. +func New(ctx context.Context, opts ...Option) (*Resource, error) { + cfg := config{} + for _, opt := range opts { + cfg = opt.apply(cfg) + } + + r := &Resource{schemaURL: cfg.schemaURL} + return r, detect(ctx, r, cfg.detectors) +} + +// NewWithAttributes creates a resource from attrs and associates the resource with a +// schema URL. If attrs contains duplicate keys, the last value will be used. If attrs +// contains any invalid items those items will be dropped. The attrs are assumed to be +// in a schema identified by schemaURL. +func NewWithAttributes(schemaURL string, attrs ...attribute.KeyValue) *Resource { + resource := NewSchemaless(attrs...) + resource.schemaURL = schemaURL + return resource +} + +// NewSchemaless creates a resource from attrs. If attrs contains duplicate keys, +// the last value will be used. If attrs contains any invalid items those items will +// be dropped. The resource will not be associated with a schema URL. If the schema +// of the attrs is known use NewWithAttributes instead. +func NewSchemaless(attrs ...attribute.KeyValue) *Resource { + if len(attrs) == 0 { + return &emptyResource + } + + // Ensure attributes comply with the specification: + // https://github.com/open-telemetry/opentelemetry-specification/blob/v1.20.0/specification/common/README.md#attribute + s, _ := attribute.NewSetWithFiltered(attrs, func(kv attribute.KeyValue) bool { + return kv.Valid() + }) + + // If attrs only contains invalid entries do not allocate a new resource. 
+ if s.Len() == 0 { + return &emptyResource + } + + return &Resource{attrs: s} //nolint +} + +// String implements the Stringer interface and provides a +// human-readable form of the resource. +// +// Avoid using this representation as the key in a map of resources, +// use Equivalent() as the key instead. +func (r *Resource) String() string { + if r == nil { + return "" + } + return r.attrs.Encoded(attribute.DefaultEncoder()) +} + +// MarshalLog is the marshaling function used by the logging system to represent this exporter. +func (r *Resource) MarshalLog() interface{} { + return struct { + Attributes attribute.Set + SchemaURL string + }{ + Attributes: r.attrs, + SchemaURL: r.schemaURL, + } +} + +// Attributes returns a copy of attributes from the resource in a sorted order. +// To avoid allocating a new slice, use an iterator. +func (r *Resource) Attributes() []attribute.KeyValue { + if r == nil { + r = Empty() + } + return r.attrs.ToSlice() +} + +// SchemaURL returns the schema URL associated with Resource r. +func (r *Resource) SchemaURL() string { + if r == nil { + return "" + } + return r.schemaURL +} + +// Iter returns an iterator of the Resource attributes. +// This is ideal to use if you do not want a copy of the attributes. +func (r *Resource) Iter() attribute.Iterator { + if r == nil { + r = Empty() + } + return r.attrs.Iter() +} + +// Equal returns true when a Resource is equivalent to this Resource. +func (r *Resource) Equal(eq *Resource) bool { + if r == nil { + r = Empty() + } + if eq == nil { + eq = Empty() + } + return r.Equivalent() == eq.Equivalent() +} + +// Merge creates a new resource by combining resource a and b. +// +// If there are common keys between resource a and b, then the value +// from resource b will overwrite the value from resource a, even +// if resource b's value is empty. +// +// The SchemaURL of the resources will be merged according to the spec rules: +// https://github.com/open-telemetry/opentelemetry-specification/blob/v1.20.0/specification/resource/sdk.md#merge +// If the resources have different non-empty schemaURL an empty resource and an error +// will be returned. +func Merge(a, b *Resource) (*Resource, error) { + if a == nil && b == nil { + return Empty(), nil + } + if a == nil { + return b, nil + } + if b == nil { + return a, nil + } + + // Merge the schema URL. + var schemaURL string + switch true { + case a.schemaURL == "": + schemaURL = b.schemaURL + case b.schemaURL == "": + schemaURL = a.schemaURL + case a.schemaURL == b.schemaURL: + schemaURL = a.schemaURL + default: + return Empty(), errMergeConflictSchemaURL + } + + // Note: 'b' attributes will overwrite 'a' with last-value-wins in attribute.Key() + // Meaning this is equivalent to: append(a.Attributes(), b.Attributes()...) + mi := attribute.NewMergeIterator(b.Set(), a.Set()) + combine := make([]attribute.KeyValue, 0, a.Len()+b.Len()) + for mi.Next() { + combine = append(combine, mi.Attribute()) + } + merged := NewWithAttributes(schemaURL, combine...) + return merged, nil +} + +// Empty returns an instance of Resource with no attributes. It is +// equivalent to a `nil` Resource. +func Empty() *Resource { + return &emptyResource +} + +// Default returns an instance of Resource with a default +// "service.name" and OpenTelemetrySDK attributes. 
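+// The result is computed once and memoized, so subsequent calls return the
+// same *Resource.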
+func Default() *Resource { + defaultResourceOnce.Do(func() { + var err error + defaultResource, err = Detect( + context.Background(), + defaultServiceNameDetector{}, + fromEnv{}, + telemetrySDK{}, + ) + if err != nil { + otel.Handle(err) + } + // If Detect did not return a valid resource, fall back to emptyResource. + if defaultResource == nil { + defaultResource = &emptyResource + } + }) + return defaultResource +} + +// Environment returns an instance of Resource with attributes +// extracted from the OTEL_RESOURCE_ATTRIBUTES environment variable. +func Environment() *Resource { + detector := &fromEnv{} + resource, err := detector.Detect(context.Background()) + if err != nil { + otel.Handle(err) + } + return resource +} + +// Equivalent returns an object that can be compared for equality +// between two resources. This value is suitable for use as a key in +// a map. +func (r *Resource) Equivalent() attribute.Distinct { + return r.Set().Equivalent() +} + +// Set returns the equivalent *attribute.Set of this resource's attributes. +func (r *Resource) Set() *attribute.Set { + if r == nil { + r = Empty() + } + return &r.attrs +} + +// MarshalJSON encodes the resource attributes as a JSON list of { "Key": +// "...", "Value": ... } pairs in order sorted by key. +func (r *Resource) MarshalJSON() ([]byte, error) { + if r == nil { + r = Empty() + } + return r.attrs.MarshalJSON() +} + +// Len returns the number of unique key-values in this Resource. +func (r *Resource) Len() int { + if r == nil { + return 0 + } + return r.attrs.Len() +} + +// Encoded returns an encoded representation of the resource. +func (r *Resource) Encoded(enc attribute.Encoder) string { + if r == nil { + return "" + } + return r.attrs.Encoded(enc) +} diff --git a/vendor/go.opentelemetry.io/otel/sdk/trace/batch_span_processor.go b/vendor/go.opentelemetry.io/otel/sdk/trace/batch_span_processor.go new file mode 100644 index 0000000000..43d5b04230 --- /dev/null +++ b/vendor/go.opentelemetry.io/otel/sdk/trace/batch_span_processor.go @@ -0,0 +1,432 @@ +// Copyright The OpenTelemetry Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package trace // import "go.opentelemetry.io/otel/sdk/trace" + +import ( + "context" + "runtime" + "sync" + "sync/atomic" + "time" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/internal/global" + "go.opentelemetry.io/otel/sdk/internal/env" + "go.opentelemetry.io/otel/trace" +) + +// Defaults for BatchSpanProcessorOptions. +const ( + DefaultMaxQueueSize = 2048 + DefaultScheduleDelay = 5000 + DefaultExportTimeout = 30000 + DefaultMaxExportBatchSize = 512 +) + +// BatchSpanProcessorOption configures a BatchSpanProcessor. +type BatchSpanProcessorOption func(o *BatchSpanProcessorOptions) + +// BatchSpanProcessorOptions is configuration settings for a +// BatchSpanProcessor. +type BatchSpanProcessorOptions struct { + // MaxQueueSize is the maximum queue size to buffer spans for delayed processing. If the + // queue gets full it drops the spans. 
Use BlockOnQueueFull to change this behavior.
+	// The default value of MaxQueueSize is 2048.
+	MaxQueueSize int
+
+	// BatchTimeout is the maximum duration for constructing a batch. Processor
+	// forcefully sends available spans when timeout is reached.
+	// The default value of BatchTimeout is 5000 msec.
+	BatchTimeout time.Duration
+
+	// ExportTimeout specifies the maximum duration for exporting spans. If the timeout
+	// is reached, the export will be cancelled.
+	// The default value of ExportTimeout is 30000 msec.
+	ExportTimeout time.Duration
+
+	// MaxExportBatchSize is the maximum number of spans to process in a single batch.
+	// If there is more than one batch worth of spans then it processes multiple batches
+	// of spans one batch after the other without any delay.
+	// The default value of MaxExportBatchSize is 512.
+	MaxExportBatchSize int
+
+	// BlockOnQueueFull, if set to true, blocks the OnEnd() and OnStart() methods
+	// when the queue is full.
+	// This option should be used carefully as blocking can severely affect the
+	// performance of an application.
+	BlockOnQueueFull bool
+}
+
+// batchSpanProcessor is a SpanProcessor that batches asynchronously-received
+// spans and sends them to a trace.Exporter when complete.
+type batchSpanProcessor struct {
+	e SpanExporter
+	o BatchSpanProcessorOptions
+
+	queue   chan ReadOnlySpan
+	dropped uint32
+
+	batch      []ReadOnlySpan
+	batchMutex sync.Mutex
+	timer      *time.Timer
+	stopWait   sync.WaitGroup
+	stopOnce   sync.Once
+	stopCh     chan struct{}
+}
+
+var _ SpanProcessor = (*batchSpanProcessor)(nil)
+
+// NewBatchSpanProcessor creates a new SpanProcessor that will send completed
+// span batches to the exporter with the supplied options.
+//
+// If the exporter is nil, the span processor will perform no action.
+func NewBatchSpanProcessor(exporter SpanExporter, options ...BatchSpanProcessorOption) SpanProcessor {
+	maxQueueSize := env.BatchSpanProcessorMaxQueueSize(DefaultMaxQueueSize)
+	maxExportBatchSize := env.BatchSpanProcessorMaxExportBatchSize(DefaultMaxExportBatchSize)
+
+	if maxExportBatchSize > maxQueueSize {
+		if DefaultMaxExportBatchSize > maxQueueSize {
+			maxExportBatchSize = maxQueueSize
+		} else {
+			maxExportBatchSize = DefaultMaxExportBatchSize
+		}
+	}
+
+	o := BatchSpanProcessorOptions{
+		BatchTimeout:       time.Duration(env.BatchSpanProcessorScheduleDelay(DefaultScheduleDelay)) * time.Millisecond,
+		ExportTimeout:      time.Duration(env.BatchSpanProcessorExportTimeout(DefaultExportTimeout)) * time.Millisecond,
+		MaxQueueSize:       maxQueueSize,
+		MaxExportBatchSize: maxExportBatchSize,
+	}
+	for _, opt := range options {
+		opt(&o)
+	}
+	bsp := &batchSpanProcessor{
+		e:      exporter,
+		o:      o,
+		batch:  make([]ReadOnlySpan, 0, o.MaxExportBatchSize),
+		timer:  time.NewTimer(o.BatchTimeout),
+		queue:  make(chan ReadOnlySpan, o.MaxQueueSize),
+		stopCh: make(chan struct{}),
+	}
+
+	bsp.stopWait.Add(1)
+	go func() {
+		defer bsp.stopWait.Done()
+		bsp.processQueue()
+		bsp.drainQueue()
+	}()
+
+	return bsp
+}
+
+// OnStart method does nothing.
+func (bsp *batchSpanProcessor) OnStart(parent context.Context, s ReadWriteSpan) {}
+
+// OnEnd method enqueues a ReadOnlySpan for later processing.
+func (bsp *batchSpanProcessor) OnEnd(s ReadOnlySpan) {
+	// Do not enqueue spans if we are just going to drop them.
+	if bsp.e == nil {
+		return
+	}
+	bsp.enqueue(s)
+}
+
+// Shutdown flushes the queue and waits until all spans are processed.
+// It only executes once. Subsequent calls do nothing.
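+//
+// A typical pattern (a sketch) is to bound Shutdown with a timeout at
+// process exit:
+//
+//	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+//	defer cancel()
+//	_ = bsp.Shutdown(ctx)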
+func (bsp *batchSpanProcessor) Shutdown(ctx context.Context) error { + var err error + bsp.stopOnce.Do(func() { + wait := make(chan struct{}) + go func() { + close(bsp.stopCh) + bsp.stopWait.Wait() + if bsp.e != nil { + if err := bsp.e.Shutdown(ctx); err != nil { + otel.Handle(err) + } + } + close(wait) + }() + // Wait until the wait group is done or the context is cancelled + select { + case <-wait: + case <-ctx.Done(): + err = ctx.Err() + } + }) + return err +} + +type forceFlushSpan struct { + ReadOnlySpan + flushed chan struct{} +} + +func (f forceFlushSpan) SpanContext() trace.SpanContext { + return trace.NewSpanContext(trace.SpanContextConfig{TraceFlags: trace.FlagsSampled}) +} + +// ForceFlush exports all ended spans that have not yet been exported. +func (bsp *batchSpanProcessor) ForceFlush(ctx context.Context) error { + var err error + if bsp.e != nil { + flushCh := make(chan struct{}) + if bsp.enqueueBlockOnQueueFull(ctx, forceFlushSpan{flushed: flushCh}) { + select { + case <-flushCh: + // Processed any items in queue prior to ForceFlush being called + case <-ctx.Done(): + return ctx.Err() + } + } + + wait := make(chan error) + go func() { + wait <- bsp.exportSpans(ctx) + close(wait) + }() + // Wait until the export is finished or the context is cancelled/timed out + select { + case err = <-wait: + case <-ctx.Done(): + err = ctx.Err() + } + } + return err +} + +// WithMaxQueueSize returns a BatchSpanProcessorOption that configures the +// maximum queue size allowed for a BatchSpanProcessor. +func WithMaxQueueSize(size int) BatchSpanProcessorOption { + return func(o *BatchSpanProcessorOptions) { + o.MaxQueueSize = size + } +} + +// WithMaxExportBatchSize returns a BatchSpanProcessorOption that configures +// the maximum export batch size allowed for a BatchSpanProcessor. +func WithMaxExportBatchSize(size int) BatchSpanProcessorOption { + return func(o *BatchSpanProcessorOptions) { + o.MaxExportBatchSize = size + } +} + +// WithBatchTimeout returns a BatchSpanProcessorOption that configures the +// maximum delay allowed for a BatchSpanProcessor before it will export any +// held span (whether the queue is full or not). +func WithBatchTimeout(delay time.Duration) BatchSpanProcessorOption { + return func(o *BatchSpanProcessorOptions) { + o.BatchTimeout = delay + } +} + +// WithExportTimeout returns a BatchSpanProcessorOption that configures the +// amount of time a BatchSpanProcessor waits for an exporter to export before +// abandoning the export. +func WithExportTimeout(timeout time.Duration) BatchSpanProcessorOption { + return func(o *BatchSpanProcessorOptions) { + o.ExportTimeout = timeout + } +} + +// WithBlocking returns a BatchSpanProcessorOption that configures a +// BatchSpanProcessor to wait for enqueue operations to succeed instead of +// dropping data when the queue is full. +func WithBlocking() BatchSpanProcessorOption { + return func(o *BatchSpanProcessorOptions) { + o.BlockOnQueueFull = true + } +} + +// exportSpans is a subroutine of processing and draining the queue. 
+func (bsp *batchSpanProcessor) exportSpans(ctx context.Context) error {
+	bsp.timer.Reset(bsp.o.BatchTimeout)
+
+	bsp.batchMutex.Lock()
+	defer bsp.batchMutex.Unlock()
+
+	if bsp.o.ExportTimeout > 0 {
+		var cancel context.CancelFunc
+		ctx, cancel = context.WithTimeout(ctx, bsp.o.ExportTimeout)
+		defer cancel()
+	}
+
+	if l := len(bsp.batch); l > 0 {
+		global.Debug("exporting spans", "count", len(bsp.batch), "total_dropped", atomic.LoadUint32(&bsp.dropped))
+		err := bsp.e.ExportSpans(ctx, bsp.batch)
+
+		// A new batch is always created after exporting, even if the batch failed to be exported.
+		//
+		// It is up to the exporter to implement any type of retry logic if a batch is failing
+		// to be exported, since it is specific to the protocol and backend being sent to.
+		bsp.batch = bsp.batch[:0]
+
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// processQueue removes spans from the `queue` channel until the processor
+// is shut down. It calls the exporter in batches of up to MaxExportBatchSize
+// waiting up to BatchTimeout to form a batch.
+func (bsp *batchSpanProcessor) processQueue() {
+	defer bsp.timer.Stop()
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	for {
+		select {
+		case <-bsp.stopCh:
+			return
+		case <-bsp.timer.C:
+			if err := bsp.exportSpans(ctx); err != nil {
+				otel.Handle(err)
+			}
+		case sd := <-bsp.queue:
+			if ffs, ok := sd.(forceFlushSpan); ok {
+				close(ffs.flushed)
+				continue
+			}
+			bsp.batchMutex.Lock()
+			bsp.batch = append(bsp.batch, sd)
+			shouldExport := len(bsp.batch) >= bsp.o.MaxExportBatchSize
+			bsp.batchMutex.Unlock()
+			if shouldExport {
+				if !bsp.timer.Stop() {
+					<-bsp.timer.C
+				}
+				if err := bsp.exportSpans(ctx); err != nil {
+					otel.Handle(err)
+				}
+			}
+		}
+	}
+}
+
+// drainQueue awaits any callers that had added to bsp.stopWait
+// to finish their enqueues, then exports the final batch.
+func (bsp *batchSpanProcessor) drainQueue() {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	for {
+		select {
+		case sd := <-bsp.queue:
+			if sd == nil {
+				if err := bsp.exportSpans(ctx); err != nil {
+					otel.Handle(err)
+				}
+				return
+			}
+
+			bsp.batchMutex.Lock()
+			bsp.batch = append(bsp.batch, sd)
+			shouldExport := len(bsp.batch) == bsp.o.MaxExportBatchSize
+			bsp.batchMutex.Unlock()
+
+			if shouldExport {
+				if err := bsp.exportSpans(ctx); err != nil {
+					otel.Handle(err)
+				}
+			}
+		default:
+			close(bsp.queue)
+		}
+	}
+}
+
+func (bsp *batchSpanProcessor) enqueue(sd ReadOnlySpan) {
+	ctx := context.TODO()
+	if bsp.o.BlockOnQueueFull {
+		bsp.enqueueBlockOnQueueFull(ctx, sd)
+	} else {
+		bsp.enqueueDrop(ctx, sd)
+	}
+}
+
+func recoverSendOnClosedChan() {
+	x := recover()
+	switch err := x.(type) {
+	case nil:
+		return
+	case runtime.Error:
+		if err.Error() == "send on closed channel" {
+			return
+		}
+	}
+	panic(x)
+}
+
+func (bsp *batchSpanProcessor) enqueueBlockOnQueueFull(ctx context.Context, sd ReadOnlySpan) bool {
+	if !sd.SpanContext().IsSampled() {
+		return false
+	}
+
+	// This ensures the bsp.queue<- below does not panic as the
+	// processor shuts down.
+	defer recoverSendOnClosedChan()
+
+	select {
+	case <-bsp.stopCh:
+		return false
+	default:
+	}
+
+	select {
+	case bsp.queue <- sd:
+		return true
+	case <-ctx.Done():
+		return false
+	}
+}
+
+func (bsp *batchSpanProcessor) enqueueDrop(ctx context.Context, sd ReadOnlySpan) bool {
+	if !sd.SpanContext().IsSampled() {
+		return false
+	}
+
+	// This ensures the bsp.queue<- below does not panic as the
+	// processor shuts down.
+ defer recoverSendOnClosedChan() + + select { + case <-bsp.stopCh: + return false + default: + } + + select { + case bsp.queue <- sd: + return true + default: + atomic.AddUint32(&bsp.dropped, 1) + } + return false +} + +// MarshalLog is the marshaling function used by the logging system to represent this exporter. +func (bsp *batchSpanProcessor) MarshalLog() interface{} { + return struct { + Type string + SpanExporter SpanExporter + Config BatchSpanProcessorOptions + }{ + Type: "BatchSpanProcessor", + SpanExporter: bsp.e, + Config: bsp.o, + } +} diff --git a/vendor/go.opentelemetry.io/otel/sdk/trace/span.go b/vendor/go.opentelemetry.io/otel/sdk/trace/span.go new file mode 100644 index 0000000000..4fcca26e08 --- /dev/null +++ b/vendor/go.opentelemetry.io/otel/sdk/trace/span.go @@ -0,0 +1,828 @@ +// Copyright The OpenTelemetry Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package trace // import "go.opentelemetry.io/otel/sdk/trace" + +import ( + "context" + "fmt" + "reflect" + "runtime" + rt "runtime/trace" + "strings" + "sync" + "time" + "unicode/utf8" + + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/sdk/instrumentation" + "go.opentelemetry.io/otel/sdk/internal" + "go.opentelemetry.io/otel/sdk/resource" + semconv "go.opentelemetry.io/otel/semconv/v1.17.0" + "go.opentelemetry.io/otel/trace" +) + +// ReadOnlySpan allows reading information from the data structure underlying a +// trace.Span. It is used in places where reading information from a span is +// necessary but changing the span isn't necessary or allowed. +// +// Warning: methods may be added to this interface in minor releases. +type ReadOnlySpan interface { + // Name returns the name of the span. + Name() string + // SpanContext returns the unique SpanContext that identifies the span. + SpanContext() trace.SpanContext + // Parent returns the unique SpanContext that identifies the parent of the + // span if one exists. If the span has no parent the returned SpanContext + // will be invalid. + Parent() trace.SpanContext + // SpanKind returns the role the span plays in a Trace. + SpanKind() trace.SpanKind + // StartTime returns the time the span started recording. + StartTime() time.Time + // EndTime returns the time the span stopped recording. It will be zero if + // the span has not ended. + EndTime() time.Time + // Attributes returns the defining attributes of the span. + // The order of the returned attributes is not guaranteed to be stable across invocations. + Attributes() []attribute.KeyValue + // Links returns all the links the span has to other spans. + Links() []Link + // Events returns all the events that occurred within in the spans + // lifetime. + Events() []Event + // Status returns the spans status. + Status() Status + // InstrumentationScope returns information about the instrumentation + // scope that created the span. 
+	InstrumentationScope() instrumentation.Scope
+	// InstrumentationLibrary returns information about the instrumentation
+	// library that created the span.
+	// Deprecated: please use InstrumentationScope instead.
+	InstrumentationLibrary() instrumentation.Library
+	// Resource returns information about the entity that produced the span.
+	Resource() *resource.Resource
+	// DroppedAttributes returns the number of attributes dropped by the span
+	// due to limits being reached.
+	DroppedAttributes() int
+	// DroppedLinks returns the number of links dropped by the span due to
+	// limits being reached.
+	DroppedLinks() int
+	// DroppedEvents returns the number of events dropped by the span due to
+	// limits being reached.
+	DroppedEvents() int
+	// ChildSpanCount returns the count of spans that consider the span a
+	// direct parent.
+	ChildSpanCount() int
+
+	// A private method to prevent users implementing the
+	// interface and so future additions to it will not
+	// violate compatibility.
+	private()
+}
+
+// ReadWriteSpan exposes the same methods as trace.Span and in addition allows
+// reading information from the underlying data structure.
+// This interface exposes the union of the methods of trace.Span (which is a
+// "write-only" span) and ReadOnlySpan. New methods for writing or reading span
+// information should be added under trace.Span or ReadOnlySpan, respectively.
+//
+// Warning: methods may be added to this interface in minor releases.
+type ReadWriteSpan interface {
+	trace.Span
+	ReadOnlySpan
+}
+
+// recordingSpan is an implementation of the OpenTelemetry Span API
+// representing the individual component of a trace that is sampled.
+type recordingSpan struct {
+	// mu protects the contents of this span.
+	mu sync.Mutex
+
+	// parent holds the parent span of this span as a trace.SpanContext.
+	parent trace.SpanContext
+
+	// spanKind represents the kind of this span as a trace.SpanKind.
+	spanKind trace.SpanKind
+
+	// name is the name of this span.
+	name string
+
+	// startTime is the time at which this span was started.
+	startTime time.Time
+
+	// endTime is the time at which this span was ended. It contains the zero
+	// value of time.Time until the span is ended.
+	endTime time.Time
+
+	// status is the status of this span.
+	status Status
+
+	// childSpanCount holds the number of child spans created for this span.
+	childSpanCount int
+
+	// spanContext holds the SpanContext of this span.
+	spanContext trace.SpanContext
+
+	// attributes is a collection of user provided key/values. The collection
+	// is constrained by a configurable maximum held by the parent
+	// TracerProvider. When additional attributes are added after this maximum
+	// is reached, the attributes the user is attempting to add are dropped.
+	// The number of dropped attributes is tracked and reported in the
+	// ReadOnlySpan exported when the span ends.
+	attributes        []attribute.KeyValue
+	droppedAttributes int
+
+	// events are stored in FIFO queue capped by configured limit.
+	events evictedQueue
+
+	// links are stored in FIFO queue capped by configured limit.
+	links evictedQueue
+
+	// executionTracerTaskEnd ends the execution tracer span.
+	executionTracerTaskEnd func()
+
+	// tracer is the SDK tracer that created this span.
+	tracer *tracer
+}
+
+var _ ReadWriteSpan = (*recordingSpan)(nil)
+var _ runtimeTracer = (*recordingSpan)(nil)
+
+// SpanContext returns the SpanContext of this span.
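+// For example, callers commonly read identifiers from the returned value
+// (a sketch, where span is any started trace.Span):
+//
+//	sc := span.SpanContext()
+//	traceID, spanID := sc.TraceID(), sc.SpanID()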
+func (s *recordingSpan) SpanContext() trace.SpanContext {
+	if s == nil {
+		return trace.SpanContext{}
+	}
+	return s.spanContext
+}
+
+// IsRecording returns if this span is being recorded. If this span has ended
+// this will return false.
+func (s *recordingSpan) IsRecording() bool {
+	if s == nil {
+		return false
+	}
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	return s.endTime.IsZero()
+}
+
+// SetStatus sets the status of the Span in the form of a code and a
+// description, overriding previous values set. The description is only
+// included in the set status when the code is for an error. If this span is
+// not being recorded then this method does nothing.
+func (s *recordingSpan) SetStatus(code codes.Code, description string) {
+	if !s.IsRecording() {
+		return
+	}
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	if s.status.Code > code {
+		return
+	}
+
+	status := Status{Code: code}
+	if code == codes.Error {
+		status.Description = description
+	}
+
+	s.status = status
+}
+
+// SetAttributes sets attributes of this span.
+//
+// If a key from attributes already exists, the value associated with that key
+// will be overwritten with the value contained in attributes.
+//
+// If this span is not being recorded then this method does nothing.
+//
+// If adding attributes to the span would exceed the maximum number of
+// attributes the span is configured to have, the last added attributes will
+// be dropped.
+func (s *recordingSpan) SetAttributes(attributes ...attribute.KeyValue) {
+	if !s.IsRecording() {
+		return
+	}
+
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	limit := s.tracer.provider.spanLimits.AttributeCountLimit
+	if limit == 0 {
+		// No attributes allowed.
+		s.droppedAttributes += len(attributes)
+		return
+	}
+
+	// If adding these attributes could exceed the capacity of s perform a
+	// de-duplication and truncation while adding to avoid over allocation.
+	if limit > 0 && len(s.attributes)+len(attributes) > limit {
+		s.addOverCapAttrs(limit, attributes)
+		return
+	}
+
+	// Otherwise, add without deduplication. When attributes are read they
+	// will be deduplicated, optimizing the operation.
+	for _, a := range attributes {
+		if !a.Valid() {
+			// Drop all invalid attributes.
+			s.droppedAttributes++
+			continue
+		}
+		a = truncateAttr(s.tracer.provider.spanLimits.AttributeValueLengthLimit, a)
+		s.attributes = append(s.attributes, a)
+	}
+}
+
+// addOverCapAttrs adds the attributes attrs to the span s while
+// de-duplicating the attributes of s and attrs and dropping attributes that
+// exceed the limit.
+//
+// This method assumes s.mu.Lock is held by the caller.
+//
+// This method should only be called when there is a possibility that adding
+// attrs to s will exceed the limit. Otherwise, attrs should be added to s
+// without checking for duplicates and all retrieval methods of the attributes
+// for s will de-duplicate as needed.
+//
+// This method assumes limit is a value > 0. The argument should be validated
+// by the caller.
+func (s *recordingSpan) addOverCapAttrs(limit int, attrs []attribute.KeyValue) {
+	// In order to not allocate more capacity to s.attributes than needed,
+	// prune and truncate this addition of attributes while adding.
+
+	// Do not set a capacity when creating this map. Benchmark testing has
+	// shown this to only add unused memory allocations in general use.
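+	//
+	// As a worked example with hypothetical values: given limit 2,
+	// s.attributes [a=1 b=2], and attrs [b=3 c=4], b is updated in place to
+	// b=3 and c is dropped (incrementing droppedAttributes), because s is
+	// already at the limit.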
+	exists := make(map[attribute.Key]int)
+	s.dedupeAttrsFromRecord(&exists)
+
+	// Now that s.attributes is deduplicated, adding unique attributes up to
+	// the capacity of s will not over allocate s.attributes.
+	for _, a := range attrs {
+		if !a.Valid() {
+			// Drop all invalid attributes.
+			s.droppedAttributes++
+			continue
+		}
+
+		if idx, ok := exists[a.Key]; ok {
+			// Perform all updates before dropping, even when at capacity.
+			s.attributes[idx] = a
+			continue
+		}
+
+		if len(s.attributes) >= limit {
+			// Do not just drop all of the remaining attributes; make sure
+			// updates are checked and performed.
+			s.droppedAttributes++
+		} else {
+			a = truncateAttr(s.tracer.provider.spanLimits.AttributeValueLengthLimit, a)
+			s.attributes = append(s.attributes, a)
+			exists[a.Key] = len(s.attributes) - 1
+		}
+	}
+}
+
+// truncateAttr returns a truncated version of attr. Only string and string
+// slice attribute values are truncated. String values are truncated to at
+// most a length of limit. Each string slice value is truncated in this fashion
+// (the slice length itself is unaffected).
+//
+// No truncation is performed for a negative limit.
+func truncateAttr(limit int, attr attribute.KeyValue) attribute.KeyValue {
+	if limit < 0 {
+		return attr
+	}
+	switch attr.Value.Type() {
+	case attribute.STRING:
+		if v := attr.Value.AsString(); len(v) > limit {
+			return attr.Key.String(safeTruncate(v, limit))
+		}
+	case attribute.STRINGSLICE:
+		v := attr.Value.AsStringSlice()
+		for i := range v {
+			if len(v[i]) > limit {
+				v[i] = safeTruncate(v[i], limit)
+			}
+		}
+		return attr.Key.StringSlice(v)
+	}
+	return attr
+}
+
+// safeTruncate truncates the string and guarantees valid UTF-8 is returned.
+func safeTruncate(input string, limit int) string {
+	if trunc, ok := safeTruncateValidUTF8(input, limit); ok {
+		return trunc
+	}
+	trunc, _ := safeTruncateValidUTF8(strings.ToValidUTF8(input, ""), limit)
+	return trunc
+}
+
+// safeTruncateValidUTF8 returns a copy of the input string safely truncated to
+// limit. The truncation is ensured to occur at the bounds of complete UTF-8
+// characters. If invalid encoding of UTF-8 is encountered, input is returned
+// with false; otherwise, the truncated input will be returned with true.
+func safeTruncateValidUTF8(input string, limit int) (string, bool) {
+	for cnt := 0; cnt <= limit; {
+		r, size := utf8.DecodeRuneInString(input[cnt:])
+		if r == utf8.RuneError {
+			return input, false
+		}
+
+		if cnt+size > limit {
+			return input[:cnt], true
+		}
+		cnt += size
+	}
+	return input, true
+}
+
+// End ends the span. This method does nothing if the span is already ended or
+// is not being recorded.
+//
+// The only SpanOption currently supported is WithTimestamp, which will set the
+// end time for a Span's life-cycle.
+//
+// If this method is called while panicking an error event is added to the
+// Span before ending it and the panic is continued.
+func (s *recordingSpan) End(options ...trace.SpanEndOption) {
+	// Do not start by checking if the span is being recorded which requires
+	// acquiring a lock. Make a minimal check that the span is not nil.
+	if s == nil {
+		return
+	}
+
+	// Store the end time as soon as possible to avoid artificially increasing
+	// the span's duration in case some operation below takes a while.
+	et := internal.MonotonicEndTime(s.startTime)
+
+	// Do a relatively expensive check now that we have an end time and see if
+	// we need to do any more processing.
+	if !s.IsRecording() {
+		return
+	}
+
+	config := trace.NewSpanEndConfig(options...)
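+
+	// End is typically deferred, so when the calling goroutine is panicking
+	// this code runs while the panic is still unwinding: recover() below
+	// observes the panic value, an exception event is recorded, and the
+	// panic is re-raised by the deferred panic(recovered).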
+	if recovered := recover(); recovered != nil {
+		// Record but don't stop the panic.
+		defer panic(recovered)
+		opts := []trace.EventOption{
+			trace.WithAttributes(
+				semconv.ExceptionType(typeStr(recovered)),
+				semconv.ExceptionMessage(fmt.Sprint(recovered)),
+			),
+		}
+
+		if config.StackTrace() {
+			opts = append(opts, trace.WithAttributes(
+				semconv.ExceptionStacktrace(recordStackTrace()),
+			))
+		}
+
+		s.addEvent(semconv.ExceptionEventName, opts...)
+	}
+
+	if s.executionTracerTaskEnd != nil {
+		s.executionTracerTaskEnd()
+	}
+
+	s.mu.Lock()
+	// Setting endTime to non-zero marks the span as ended and not recording.
+	if config.Timestamp().IsZero() {
+		s.endTime = et
+	} else {
+		s.endTime = config.Timestamp()
+	}
+	s.mu.Unlock()
+
+	sps := s.tracer.provider.getSpanProcessors()
+	if len(sps) == 0 {
+		return
+	}
+	snap := s.snapshot()
+	for _, sp := range sps {
+		sp.sp.OnEnd(snap)
+	}
+}
+
+// RecordError will record err as a span event for this span. An additional call to
+// SetStatus is required if the Status of the Span should be set to Error; this method
+// does not change the Span status. If this span is not being recorded or err is nil
+// then this method does nothing.
+func (s *recordingSpan) RecordError(err error, opts ...trace.EventOption) {
+	if s == nil || err == nil || !s.IsRecording() {
+		return
+	}
+
+	opts = append(opts, trace.WithAttributes(
+		semconv.ExceptionType(typeStr(err)),
+		semconv.ExceptionMessage(err.Error()),
+	))
+
+	c := trace.NewEventConfig(opts...)
+	if c.StackTrace() {
+		opts = append(opts, trace.WithAttributes(
+			semconv.ExceptionStacktrace(recordStackTrace()),
+		))
+	}
+
+	s.addEvent(semconv.ExceptionEventName, opts...)
+}
+
+func typeStr(i interface{}) string {
+	t := reflect.TypeOf(i)
+	if t.PkgPath() == "" && t.Name() == "" {
+		// Likely a builtin type.
+		return t.String()
+	}
+	return fmt.Sprintf("%s.%s", t.PkgPath(), t.Name())
+}
+
+func recordStackTrace() string {
+	stackTrace := make([]byte, 2048)
+	n := runtime.Stack(stackTrace, false)
+
+	return string(stackTrace[0:n])
+}
+
+// AddEvent adds an event with the provided name and options. If this span is
+// not being recorded then this method does nothing.
+func (s *recordingSpan) AddEvent(name string, o ...trace.EventOption) {
+	if !s.IsRecording() {
+		return
+	}
+	s.addEvent(name, o...)
+}
+
+func (s *recordingSpan) addEvent(name string, o ...trace.EventOption) {
+	c := trace.NewEventConfig(o...)
+	e := Event{Name: name, Attributes: c.Attributes(), Time: c.Timestamp()}
+
+	// Discard attributes over limit.
+	limit := s.tracer.provider.spanLimits.AttributePerEventCountLimit
+	if limit == 0 {
+		// Drop all attributes.
+		e.DroppedAttributeCount = len(e.Attributes)
+		e.Attributes = nil
+	} else if limit > 0 && len(e.Attributes) > limit {
+		// Drop over capacity.
+		e.DroppedAttributeCount = len(e.Attributes) - limit
+		e.Attributes = e.Attributes[:limit]
+	}
+
+	s.mu.Lock()
+	s.events.add(e)
+	s.mu.Unlock()
+}
+
+// SetName sets the name of this span. If this span is not being recorded then
+// this method does nothing.
+func (s *recordingSpan) SetName(name string) {
+	if !s.IsRecording() {
+		return
+	}
+
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	s.name = name
+}
+
+// Name returns the name of this span.
+func (s *recordingSpan) Name() string {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	return s.name
+}
+
+// Parent returns the SpanContext of this span's parent span.
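+// For a root span the parent was never set, so the zero-value SpanContext is
+// returned; its IsValid method reports false.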
+func (s *recordingSpan) Parent() trace.SpanContext {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	return s.parent
+}
+
+// SpanKind returns the SpanKind of this span.
+func (s *recordingSpan) SpanKind() trace.SpanKind {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	return s.spanKind
+}
+
+// StartTime returns the time this span started.
+func (s *recordingSpan) StartTime() time.Time {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	return s.startTime
+}
+
+// EndTime returns the time this span ended. For spans that have not yet
+// ended, the returned value will be the zero value of time.Time.
+func (s *recordingSpan) EndTime() time.Time {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	return s.endTime
+}
+
+// Attributes returns the attributes of this span.
+//
+// The order of the returned attributes is not guaranteed to be stable.
+func (s *recordingSpan) Attributes() []attribute.KeyValue {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	s.dedupeAttrs()
+	return s.attributes
+}
+
+// dedupeAttrs deduplicates the attributes of s to fit capacity.
+//
+// This method assumes s.mu.Lock is held by the caller.
+func (s *recordingSpan) dedupeAttrs() {
+	// Do not set a capacity when creating this map. Benchmark testing has
+	// shown this to only add unused memory allocations in general use.
+	exists := make(map[attribute.Key]int)
+	s.dedupeAttrsFromRecord(&exists)
+}
+
+// dedupeAttrsFromRecord deduplicates the attributes of s to fit capacity
+// using record as the record of unique attribute keys to their index.
+//
+// This method assumes s.mu.Lock is held by the caller.
+func (s *recordingSpan) dedupeAttrsFromRecord(record *map[attribute.Key]int) {
+	// Use the fact that slices share the same backing array.
+	unique := s.attributes[:0]
+	for _, a := range s.attributes {
+		if idx, ok := (*record)[a.Key]; ok {
+			unique[idx] = a
+		} else {
+			unique = append(unique, a)
+			(*record)[a.Key] = len(unique) - 1
+		}
+	}
+	// s.attributes have element types of attribute.KeyValue. These types are
+	// not pointers and they themselves do not contain pointer fields,
+	// therefore the duplicate values do not need to be zeroed for them to be
+	// garbage collected.
+	s.attributes = unique
+}
+
+// Links returns the links of this span.
+func (s *recordingSpan) Links() []Link {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	if len(s.links.queue) == 0 {
+		return []Link{}
+	}
+	return s.interfaceArrayToLinksArray()
+}
+
+// Events returns the events of this span.
+func (s *recordingSpan) Events() []Event {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	if len(s.events.queue) == 0 {
+		return []Event{}
+	}
+	return s.interfaceArrayToEventArray()
+}
+
+// Status returns the status of this span.
+func (s *recordingSpan) Status() Status {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	return s.status
+}
+
+// InstrumentationScope returns the instrumentation.Scope associated with
+// the Tracer that created this span.
+func (s *recordingSpan) InstrumentationScope() instrumentation.Scope {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	return s.tracer.instrumentationScope
+}
+
+// InstrumentationLibrary returns the instrumentation.Library associated with
+// the Tracer that created this span.
+func (s *recordingSpan) InstrumentationLibrary() instrumentation.Library {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	return s.tracer.instrumentationScope
+}
+
+// Resource returns the Resource associated with the Tracer that created this
+// span.
+func (s *recordingSpan) Resource() *resource.Resource { + s.mu.Lock() + defer s.mu.Unlock() + return s.tracer.provider.resource +} + +func (s *recordingSpan) addLink(link trace.Link) { + if !s.IsRecording() || !link.SpanContext.IsValid() { + return + } + + l := Link{SpanContext: link.SpanContext, Attributes: link.Attributes} + + // Discard attributes over limit. + limit := s.tracer.provider.spanLimits.AttributePerLinkCountLimit + if limit == 0 { + // Drop all attributes. + l.DroppedAttributeCount = len(l.Attributes) + l.Attributes = nil + } else if limit > 0 && len(l.Attributes) > limit { + l.DroppedAttributeCount = len(l.Attributes) - limit + l.Attributes = l.Attributes[:limit] + } + + s.mu.Lock() + s.links.add(l) + s.mu.Unlock() +} + +// DroppedAttributes returns the number of attributes dropped by the span +// due to limits being reached. +func (s *recordingSpan) DroppedAttributes() int { + s.mu.Lock() + defer s.mu.Unlock() + return s.droppedAttributes +} + +// DroppedLinks returns the number of links dropped by the span due to limits +// being reached. +func (s *recordingSpan) DroppedLinks() int { + s.mu.Lock() + defer s.mu.Unlock() + return s.links.droppedCount +} + +// DroppedEvents returns the number of events dropped by the span due to +// limits being reached. +func (s *recordingSpan) DroppedEvents() int { + s.mu.Lock() + defer s.mu.Unlock() + return s.events.droppedCount +} + +// ChildSpanCount returns the count of spans that consider the span a +// direct parent. +func (s *recordingSpan) ChildSpanCount() int { + s.mu.Lock() + defer s.mu.Unlock() + return s.childSpanCount +} + +// TracerProvider returns a trace.TracerProvider that can be used to generate +// additional Spans on the same telemetry pipeline as the current Span. +func (s *recordingSpan) TracerProvider() trace.TracerProvider { + return s.tracer.provider +} + +// snapshot creates a read-only copy of the current state of the span. +func (s *recordingSpan) snapshot() ReadOnlySpan { + var sd snapshot + s.mu.Lock() + defer s.mu.Unlock() + + sd.endTime = s.endTime + sd.instrumentationScope = s.tracer.instrumentationScope + sd.name = s.name + sd.parent = s.parent + sd.resource = s.tracer.provider.resource + sd.spanContext = s.spanContext + sd.spanKind = s.spanKind + sd.startTime = s.startTime + sd.status = s.status + sd.childSpanCount = s.childSpanCount + + if len(s.attributes) > 0 { + s.dedupeAttrs() + sd.attributes = s.attributes + } + sd.droppedAttributeCount = s.droppedAttributes + if len(s.events.queue) > 0 { + sd.events = s.interfaceArrayToEventArray() + sd.droppedEventCount = s.events.droppedCount + } + if len(s.links.queue) > 0 { + sd.links = s.interfaceArrayToLinksArray() + sd.droppedLinkCount = s.links.droppedCount + } + return &sd +} + +func (s *recordingSpan) interfaceArrayToLinksArray() []Link { + linkArr := make([]Link, 0) + for _, value := range s.links.queue { + linkArr = append(linkArr, value.(Link)) + } + return linkArr +} + +func (s *recordingSpan) interfaceArrayToEventArray() []Event { + eventArr := make([]Event, 0) + for _, value := range s.events.queue { + eventArr = append(eventArr, value.(Event)) + } + return eventArr +} + +func (s *recordingSpan) addChild() { + if !s.IsRecording() { + return + } + s.mu.Lock() + s.childSpanCount++ + s.mu.Unlock() +} + +func (*recordingSpan) private() {} + +// runtimeTrace starts a "runtime/trace".Task for the span and returns a +// context containing the task. 
+func (s *recordingSpan) runtimeTrace(ctx context.Context) context.Context { + if !rt.IsEnabled() { + // Avoid additional overhead if runtime/trace is not enabled. + return ctx + } + nctx, task := rt.NewTask(ctx, s.name) + + s.mu.Lock() + s.executionTracerTaskEnd = task.End + s.mu.Unlock() + + return nctx +} + +// nonRecordingSpan is a minimal implementation of the OpenTelemetry Span API +// that wraps a SpanContext. It performs no operations other than to return +// the wrapped SpanContext or TracerProvider that created it. +type nonRecordingSpan struct { + // tracer is the SDK tracer that created this span. + tracer *tracer + sc trace.SpanContext +} + +var _ trace.Span = nonRecordingSpan{} + +// SpanContext returns the wrapped SpanContext. +func (s nonRecordingSpan) SpanContext() trace.SpanContext { return s.sc } + +// IsRecording always returns false. +func (nonRecordingSpan) IsRecording() bool { return false } + +// SetStatus does nothing. +func (nonRecordingSpan) SetStatus(codes.Code, string) {} + +// SetError does nothing. +func (nonRecordingSpan) SetError(bool) {} + +// SetAttributes does nothing. +func (nonRecordingSpan) SetAttributes(...attribute.KeyValue) {} + +// End does nothing. +func (nonRecordingSpan) End(...trace.SpanEndOption) {} + +// RecordError does nothing. +func (nonRecordingSpan) RecordError(error, ...trace.EventOption) {} + +// AddEvent does nothing. +func (nonRecordingSpan) AddEvent(string, ...trace.EventOption) {} + +// SetName does nothing. +func (nonRecordingSpan) SetName(string) {} + +// TracerProvider returns the trace.TracerProvider that provided the Tracer +// that created this span. +func (s nonRecordingSpan) TracerProvider() trace.TracerProvider { return s.tracer.provider } + +func isRecording(s SamplingResult) bool { + return s.Decision == RecordOnly || s.Decision == RecordAndSample +} + +func isSampled(s SamplingResult) bool { + return s.Decision == RecordAndSample +} + +// Status is the classified state of a Span. +type Status struct { + // Code is an identifier of a Spans state classification. + Code codes.Code + // Description is a user hint about why that status was set. It is only + // applicable when Code is Error. + Description string +} diff --git a/vendor/go.opentelemetry.io/otel/sdk/version.go b/vendor/go.opentelemetry.io/otel/sdk/version.go new file mode 100644 index 0000000000..dbef90b0df --- /dev/null +++ b/vendor/go.opentelemetry.io/otel/sdk/version.go @@ -0,0 +1,20 @@ +// Copyright The OpenTelemetry Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sdk // import "go.opentelemetry.io/otel/sdk" + +// Version is the current release version of the OpenTelemetry SDK in use. 
+func Version() string {
+	return "1.16.0"
+}
diff --git a/vendor/go.opentelemetry.io/otel/semconv/internal/v2/http.go b/vendor/go.opentelemetry.io/otel/semconv/internal/v2/http.go
new file mode 100644
index 0000000000..12d6b520f5
--- /dev/null
+++ b/vendor/go.opentelemetry.io/otel/semconv/internal/v2/http.go
@@ -0,0 +1,404 @@
+// Copyright The OpenTelemetry Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package internal // import "go.opentelemetry.io/otel/semconv/internal/v2"
+
+import (
+	"fmt"
+	"net/http"
+	"strings"
+
+	"go.opentelemetry.io/otel/attribute"
+	"go.opentelemetry.io/otel/codes"
+)
+
+// HTTPConv are the HTTP semantic convention attributes defined for a version
+// of the OpenTelemetry specification.
+type HTTPConv struct {
+	NetConv *NetConv
+
+	EnduserIDKey                 attribute.Key
+	HTTPClientIPKey              attribute.Key
+	HTTPFlavorKey                attribute.Key
+	HTTPMethodKey                attribute.Key
+	HTTPRequestContentLengthKey  attribute.Key
+	HTTPResponseContentLengthKey attribute.Key
+	HTTPRouteKey                 attribute.Key
+	HTTPSchemeHTTP               attribute.KeyValue
+	HTTPSchemeHTTPS              attribute.KeyValue
+	HTTPStatusCodeKey            attribute.Key
+	HTTPTargetKey                attribute.Key
+	HTTPURLKey                   attribute.Key
+	HTTPUserAgentKey             attribute.Key
+}
+
+// ClientResponse returns attributes for an HTTP response received by a client
+// from a server. The following attributes are returned if the related values
+// are defined in resp: "http.status_code", "http.response_content_length".
+//
+// This does not add all OpenTelemetry required attributes for an HTTP event;
+// it assumes ClientRequest was used to create the span with a complete set of
+// attributes. A complete set of attributes can be generated using the
+// request contained in resp. For example:
+//
+//	append(ClientResponse(resp), ClientRequest(resp.Request)...)
+func (c *HTTPConv) ClientResponse(resp *http.Response) []attribute.KeyValue {
+	var n int
+	if resp.StatusCode > 0 {
+		n++
+	}
+	if resp.ContentLength > 0 {
+		n++
+	}
+
+	attrs := make([]attribute.KeyValue, 0, n)
+	if resp.StatusCode > 0 {
+		attrs = append(attrs, c.HTTPStatusCodeKey.Int(resp.StatusCode))
+	}
+	if resp.ContentLength > 0 {
+		attrs = append(attrs, c.HTTPResponseContentLengthKey.Int(int(resp.ContentLength)))
+	}
+	return attrs
+}
+
+// ClientRequest returns attributes for an HTTP request made by a client. The
+// following attributes are always returned: "http.url", "http.flavor",
+// "http.method", "net.peer.name". The following attributes are returned if the
+// related values are defined in req: "net.peer.port", "http.user_agent",
+// "http.request_content_length", "enduser.id".
+func (c *HTTPConv) ClientRequest(req *http.Request) []attribute.KeyValue {
+	n := 4 // URL, method, proto, and peer name.
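+	// (n is only a capacity hint for the attrs slice built below;
+	// miscounting costs at most an extra allocation.)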
+	var h string
+	if req.URL != nil {
+		h = req.URL.Host
+	}
+	peer, p := firstHostPort(h, req.Header.Get("Host"))
+	port := requiredHTTPPort(req.URL != nil && req.URL.Scheme == "https", p)
+	if port > 0 {
+		n++
+	}
+	useragent := req.UserAgent()
+	if useragent != "" {
+		n++
+	}
+	if req.ContentLength > 0 {
+		n++
+	}
+	userID, _, hasUserID := req.BasicAuth()
+	if hasUserID {
+		n++
+	}
+	attrs := make([]attribute.KeyValue, 0, n)
+
+	attrs = append(attrs, c.method(req.Method))
+	attrs = append(attrs, c.proto(req.Proto))
+
+	var u string
+	if req.URL != nil {
+		// Remove any username/password info that may be in the URL.
+		userinfo := req.URL.User
+		req.URL.User = nil
+		u = req.URL.String()
+		// Restore any username/password info that was removed.
+		req.URL.User = userinfo
+	}
+	attrs = append(attrs, c.HTTPURLKey.String(u))
+
+	attrs = append(attrs, c.NetConv.PeerName(peer))
+	if port > 0 {
+		attrs = append(attrs, c.NetConv.PeerPort(port))
+	}
+
+	if useragent != "" {
+		attrs = append(attrs, c.HTTPUserAgentKey.String(useragent))
+	}
+
+	if l := req.ContentLength; l > 0 {
+		attrs = append(attrs, c.HTTPRequestContentLengthKey.Int64(l))
+	}
+
+	if hasUserID {
+		attrs = append(attrs, c.EnduserIDKey.String(userID))
+	}
+
+	return attrs
+}
+
+// ServerRequest returns attributes for an HTTP request received by a server.
+//
+// The server must be the primary server name if it is known. For example, this
+// would be the ServerName directive
+// (https://httpd.apache.org/docs/2.4/mod/core.html#servername) for an Apache
+// server, and the server_name directive
+// (http://nginx.org/en/docs/http/ngx_http_core_module.html#server_name) for an
+// nginx server. More generically, the primary server name would be the host
+// header value that matches the default virtual host of an HTTP server. It
+// should include the host identifier and, if a port is used to route to the
+// server, that port identifier should be included as an appropriate port
+// suffix.
+//
+// If the primary server name is not known, server should be an empty string.
+// The req Host will be used to determine the server instead.
+//
+// The following attributes are always returned: "http.method", "http.scheme",
+// "http.flavor", "http.target", "net.host.name". The following attributes are
+// returned if the related values are defined in req: "net.host.port",
+// "net.sock.peer.addr", "net.sock.peer.port", "http.user_agent", "enduser.id",
+// "http.client_ip".
+func (c *HTTPConv) ServerRequest(server string, req *http.Request) []attribute.KeyValue {
+	// TODO: This currently does not add the specification required
+	// `http.target` attribute. It has too high of a cardinality to safely be
+	// added. An alternate should be added, or this comment removed, when it is
+	// addressed by the specification. If it is ultimately decided to continue
+	// not including the attribute, the HTTPTargetKey field of the HTTPConv
+	// should be removed as well.
+
+	n := 4 // Method, scheme, proto, and host name.
+	var host string
+	var p int
+	if server == "" {
+		host, p = splitHostPort(req.Host)
+	} else {
+		// Prioritize the primary server name.
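+		// The port parsed from req.Host is used only as a fallback when
+		// the primary server name does not carry one.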
+ host, p = splitHostPort(server) + if p < 0 { + _, p = splitHostPort(req.Host) + } + } + hostPort := requiredHTTPPort(req.TLS != nil, p) + if hostPort > 0 { + n++ + } + peer, peerPort := splitHostPort(req.RemoteAddr) + if peer != "" { + n++ + if peerPort > 0 { + n++ + } + } + useragent := req.UserAgent() + if useragent != "" { + n++ + } + userID, _, hasUserID := req.BasicAuth() + if hasUserID { + n++ + } + clientIP := serverClientIP(req.Header.Get("X-Forwarded-For")) + if clientIP != "" { + n++ + } + attrs := make([]attribute.KeyValue, 0, n) + + attrs = append(attrs, c.method(req.Method)) + attrs = append(attrs, c.scheme(req.TLS != nil)) + attrs = append(attrs, c.proto(req.Proto)) + attrs = append(attrs, c.NetConv.HostName(host)) + + if hostPort > 0 { + attrs = append(attrs, c.NetConv.HostPort(hostPort)) + } + + if peer != "" { + // The Go HTTP server sets RemoteAddr to "IP:port", this will not be a + // file-path that would be interpreted with a sock family. + attrs = append(attrs, c.NetConv.SockPeerAddr(peer)) + if peerPort > 0 { + attrs = append(attrs, c.NetConv.SockPeerPort(peerPort)) + } + } + + if useragent != "" { + attrs = append(attrs, c.HTTPUserAgentKey.String(useragent)) + } + + if hasUserID { + attrs = append(attrs, c.EnduserIDKey.String(userID)) + } + + if clientIP != "" { + attrs = append(attrs, c.HTTPClientIPKey.String(clientIP)) + } + + return attrs +} + +func (c *HTTPConv) method(method string) attribute.KeyValue { + if method == "" { + return c.HTTPMethodKey.String(http.MethodGet) + } + return c.HTTPMethodKey.String(method) +} + +func (c *HTTPConv) scheme(https bool) attribute.KeyValue { // nolint:revive + if https { + return c.HTTPSchemeHTTPS + } + return c.HTTPSchemeHTTP +} + +func (c *HTTPConv) proto(proto string) attribute.KeyValue { + switch proto { + case "HTTP/1.0": + return c.HTTPFlavorKey.String("1.0") + case "HTTP/1.1": + return c.HTTPFlavorKey.String("1.1") + case "HTTP/2": + return c.HTTPFlavorKey.String("2.0") + case "HTTP/3": + return c.HTTPFlavorKey.String("3.0") + default: + return c.HTTPFlavorKey.String(proto) + } +} + +func serverClientIP(xForwardedFor string) string { + if idx := strings.Index(xForwardedFor, ","); idx >= 0 { + xForwardedFor = xForwardedFor[:idx] + } + return xForwardedFor +} + +func requiredHTTPPort(https bool, port int) int { // nolint:revive + if https { + if port > 0 && port != 443 { + return port + } + } else { + if port > 0 && port != 80 { + return port + } + } + return -1 +} + +// Return the request host and port from the first non-empty source. +func firstHostPort(source ...string) (host string, port int) { + for _, hostport := range source { + host, port = splitHostPort(hostport) + if host != "" || port > 0 { + break + } + } + return +} + +// RequestHeader returns the contents of h as OpenTelemetry attributes. +func (c *HTTPConv) RequestHeader(h http.Header) []attribute.KeyValue { + return c.header("http.request.header", h) +} + +// ResponseHeader returns the contents of h as OpenTelemetry attributes. 
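+// For example (a sketch): a response header "Content-Type: text/html" is
+// converted to the attribute http.response.header.content_type with the
+// value ["text/html"].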
+func (c *HTTPConv) ResponseHeader(h http.Header) []attribute.KeyValue { + return c.header("http.response.header", h) +} + +func (c *HTTPConv) header(prefix string, h http.Header) []attribute.KeyValue { + key := func(k string) attribute.Key { + k = strings.ToLower(k) + k = strings.ReplaceAll(k, "-", "_") + k = fmt.Sprintf("%s.%s", prefix, k) + return attribute.Key(k) + } + + attrs := make([]attribute.KeyValue, 0, len(h)) + for k, v := range h { + attrs = append(attrs, key(k).StringSlice(v)) + } + return attrs +} + +// ClientStatus returns a span status code and message for an HTTP status code +// value received by a client. +func (c *HTTPConv) ClientStatus(code int) (codes.Code, string) { + stat, valid := validateHTTPStatusCode(code) + if !valid { + return stat, fmt.Sprintf("Invalid HTTP status code %d", code) + } + return stat, "" +} + +// ServerStatus returns a span status code and message for an HTTP status code +// value returned by a server. Status codes in the 400-499 range are not +// returned as errors. +func (c *HTTPConv) ServerStatus(code int) (codes.Code, string) { + stat, valid := validateHTTPStatusCode(code) + if !valid { + return stat, fmt.Sprintf("Invalid HTTP status code %d", code) + } + + if code/100 == 4 { + return codes.Unset, "" + } + return stat, "" +} + +type codeRange struct { + fromInclusive int + toInclusive int +} + +func (r codeRange) contains(code int) bool { + return r.fromInclusive <= code && code <= r.toInclusive +} + +var validRangesPerCategory = map[int][]codeRange{ + 1: { + {http.StatusContinue, http.StatusEarlyHints}, + }, + 2: { + {http.StatusOK, http.StatusAlreadyReported}, + {http.StatusIMUsed, http.StatusIMUsed}, + }, + 3: { + {http.StatusMultipleChoices, http.StatusUseProxy}, + {http.StatusTemporaryRedirect, http.StatusPermanentRedirect}, + }, + 4: { + {http.StatusBadRequest, http.StatusTeapot}, // yes, teapot is so useful… + {http.StatusMisdirectedRequest, http.StatusUpgradeRequired}, + {http.StatusPreconditionRequired, http.StatusTooManyRequests}, + {http.StatusRequestHeaderFieldsTooLarge, http.StatusRequestHeaderFieldsTooLarge}, + {http.StatusUnavailableForLegalReasons, http.StatusUnavailableForLegalReasons}, + }, + 5: { + {http.StatusInternalServerError, http.StatusLoopDetected}, + {http.StatusNotExtended, http.StatusNetworkAuthenticationRequired}, + }, +} + +// validateHTTPStatusCode validates the HTTP status code and returns +// corresponding span status code. If the `code` is not a valid HTTP status +// code, returns span status Error and false. +func validateHTTPStatusCode(code int) (codes.Code, bool) { + category := code / 100 + ranges, ok := validRangesPerCategory[category] + if !ok { + return codes.Error, false + } + ok = false + for _, crange := range ranges { + ok = crange.contains(code) + if ok { + break + } + } + if !ok { + return codes.Error, false + } + if category > 0 && category < 4 { + return codes.Unset, true + } + return codes.Error, true +} diff --git a/vendor/go.opentelemetry.io/otel/semconv/internal/v2/net.go b/vendor/go.opentelemetry.io/otel/semconv/internal/v2/net.go new file mode 100644 index 0000000000..4a711133a0 --- /dev/null +++ b/vendor/go.opentelemetry.io/otel/semconv/internal/v2/net.go @@ -0,0 +1,324 @@ +// Copyright The OpenTelemetry Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal // import "go.opentelemetry.io/otel/semconv/internal/v2" + +import ( + "net" + "strconv" + "strings" + + "go.opentelemetry.io/otel/attribute" +) + +// NetConv are the network semantic convention attributes defined for a version +// of the OpenTelemetry specification. +type NetConv struct { + NetHostNameKey attribute.Key + NetHostPortKey attribute.Key + NetPeerNameKey attribute.Key + NetPeerPortKey attribute.Key + NetSockFamilyKey attribute.Key + NetSockPeerAddrKey attribute.Key + NetSockPeerPortKey attribute.Key + NetSockHostAddrKey attribute.Key + NetSockHostPortKey attribute.Key + NetTransportOther attribute.KeyValue + NetTransportTCP attribute.KeyValue + NetTransportUDP attribute.KeyValue + NetTransportInProc attribute.KeyValue +} + +func (c *NetConv) Transport(network string) attribute.KeyValue { + switch network { + case "tcp", "tcp4", "tcp6": + return c.NetTransportTCP + case "udp", "udp4", "udp6": + return c.NetTransportUDP + case "unix", "unixgram", "unixpacket": + return c.NetTransportInProc + default: + // "ip:*", "ip4:*", and "ip6:*" all are considered other. + return c.NetTransportOther + } +} + +// Host returns attributes for a network host address. +func (c *NetConv) Host(address string) []attribute.KeyValue { + h, p := splitHostPort(address) + var n int + if h != "" { + n++ + if p > 0 { + n++ + } + } + + if n == 0 { + return nil + } + + attrs := make([]attribute.KeyValue, 0, n) + attrs = append(attrs, c.HostName(h)) + if p > 0 { + attrs = append(attrs, c.HostPort(int(p))) + } + return attrs +} + +// Server returns attributes for a network listener listening at address. See +// net.Listen for information about acceptable address values, address should +// be the same as the one used to create ln. If ln is nil, only network host +// attributes will be returned that describe address. Otherwise, the socket +// level information about ln will also be included. +func (c *NetConv) Server(address string, ln net.Listener) []attribute.KeyValue { + if ln == nil { + return c.Host(address) + } + + lAddr := ln.Addr() + if lAddr == nil { + return c.Host(address) + } + + hostName, hostPort := splitHostPort(address) + sockHostAddr, sockHostPort := splitHostPort(lAddr.String()) + network := lAddr.Network() + sockFamily := family(network, sockHostAddr) + + n := nonZeroStr(hostName, network, sockHostAddr, sockFamily) + n += positiveInt(hostPort, sockHostPort) + attr := make([]attribute.KeyValue, 0, n) + if hostName != "" { + attr = append(attr, c.HostName(hostName)) + if hostPort > 0 { + // Only if net.host.name is set should net.host.port be. + attr = append(attr, c.HostPort(hostPort)) + } + } + if network != "" { + attr = append(attr, c.Transport(network)) + } + if sockFamily != "" { + attr = append(attr, c.NetSockFamilyKey.String(sockFamily)) + } + if sockHostAddr != "" { + attr = append(attr, c.NetSockHostAddrKey.String(sockHostAddr)) + if sockHostPort > 0 { + // Only if net.sock.host.addr is set should net.sock.host.port be. 
+ attr = append(attr, c.NetSockHostPortKey.Int(sockHostPort)) + } + } + return attr +} + +func (c *NetConv) HostName(name string) attribute.KeyValue { + return c.NetHostNameKey.String(name) +} + +func (c *NetConv) HostPort(port int) attribute.KeyValue { + return c.NetHostPortKey.Int(port) +} + +// Client returns attributes for a client network connection to address. See +// net.Dial for information about acceptable address values, address should be +// the same as the one used to create conn. If conn is nil, only network peer +// attributes will be returned that describe address. Otherwise, the socket +// level information about conn will also be included. +func (c *NetConv) Client(address string, conn net.Conn) []attribute.KeyValue { + if conn == nil { + return c.Peer(address) + } + + lAddr, rAddr := conn.LocalAddr(), conn.RemoteAddr() + + var network string + switch { + case lAddr != nil: + network = lAddr.Network() + case rAddr != nil: + network = rAddr.Network() + default: + return c.Peer(address) + } + + peerName, peerPort := splitHostPort(address) + var ( + sockFamily string + sockPeerAddr string + sockPeerPort int + sockHostAddr string + sockHostPort int + ) + + if lAddr != nil { + sockHostAddr, sockHostPort = splitHostPort(lAddr.String()) + } + + if rAddr != nil { + sockPeerAddr, sockPeerPort = splitHostPort(rAddr.String()) + } + + switch { + case sockHostAddr != "": + sockFamily = family(network, sockHostAddr) + case sockPeerAddr != "": + sockFamily = family(network, sockPeerAddr) + } + + n := nonZeroStr(peerName, network, sockPeerAddr, sockHostAddr, sockFamily) + n += positiveInt(peerPort, sockPeerPort, sockHostPort) + attr := make([]attribute.KeyValue, 0, n) + if peerName != "" { + attr = append(attr, c.PeerName(peerName)) + if peerPort > 0 { + // Only if net.peer.name is set should net.peer.port be. + attr = append(attr, c.PeerPort(peerPort)) + } + } + if network != "" { + attr = append(attr, c.Transport(network)) + } + if sockFamily != "" { + attr = append(attr, c.NetSockFamilyKey.String(sockFamily)) + } + if sockPeerAddr != "" { + attr = append(attr, c.NetSockPeerAddrKey.String(sockPeerAddr)) + if sockPeerPort > 0 { + // Only if net.sock.peer.addr is set should net.sock.peer.port be. + attr = append(attr, c.NetSockPeerPortKey.Int(sockPeerPort)) + } + } + if sockHostAddr != "" { + attr = append(attr, c.NetSockHostAddrKey.String(sockHostAddr)) + if sockHostPort > 0 { + // Only if net.sock.host.addr is set should net.sock.host.port be. + attr = append(attr, c.NetSockHostPortKey.Int(sockHostPort)) + } + } + return attr +} + +func family(network, address string) string { + switch network { + case "unix", "unixgram", "unixpacket": + return "unix" + default: + if ip := net.ParseIP(address); ip != nil { + if ip.To4() == nil { + return "inet6" + } + return "inet" + } + } + return "" +} + +func nonZeroStr(strs ...string) int { + var n int + for _, str := range strs { + if str != "" { + n++ + } + } + return n +} + +func positiveInt(ints ...int) int { + var n int + for _, i := range ints { + if i > 0 { + n++ + } + } + return n +} + +// Peer returns attributes for a network peer address. 
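+// For example (a sketch): Peer("example.com:8080") yields
+// net.peer.name="example.com" and net.peer.port=8080, while Peer("") yields
+// no attributes at all.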
+func (c *NetConv) Peer(address string) []attribute.KeyValue { + h, p := splitHostPort(address) + var n int + if h != "" { + n++ + if p > 0 { + n++ + } + } + + if n == 0 { + return nil + } + + attrs := make([]attribute.KeyValue, 0, n) + attrs = append(attrs, c.PeerName(h)) + if p > 0 { + attrs = append(attrs, c.PeerPort(int(p))) + } + return attrs +} + +func (c *NetConv) PeerName(name string) attribute.KeyValue { + return c.NetPeerNameKey.String(name) +} + +func (c *NetConv) PeerPort(port int) attribute.KeyValue { + return c.NetPeerPortKey.Int(port) +} + +func (c *NetConv) SockPeerAddr(addr string) attribute.KeyValue { + return c.NetSockPeerAddrKey.String(addr) +} + +func (c *NetConv) SockPeerPort(port int) attribute.KeyValue { + return c.NetSockPeerPortKey.Int(port) +} + +// splitHostPort splits a network address hostport of the form "host", +// "host%zone", "[host]", "[host%zone], "host:port", "host%zone:port", +// "[host]:port", "[host%zone]:port", or ":port" into host or host%zone and +// port. +// +// An empty host is returned if it is not provided or unparsable. A negative +// port is returned if it is not provided or unparsable. +func splitHostPort(hostport string) (host string, port int) { + port = -1 + + if strings.HasPrefix(hostport, "[") { + addrEnd := strings.LastIndex(hostport, "]") + if addrEnd < 0 { + // Invalid hostport. + return + } + if i := strings.LastIndex(hostport[addrEnd:], ":"); i < 0 { + host = hostport[1:addrEnd] + return + } + } else { + if i := strings.LastIndex(hostport, ":"); i < 0 { + host = hostport + return + } + } + + host, pStr, err := net.SplitHostPort(hostport) + if err != nil { + return + } + + p, err := strconv.ParseUint(pStr, 10, 16) + if err != nil { + return + } + return host, int(p) +} diff --git a/vendor/go.opentelemetry.io/otel/semconv/v1.17.0/httpconv/http.go b/vendor/go.opentelemetry.io/otel/semconv/v1.17.0/httpconv/http.go new file mode 100644 index 0000000000..fc43808fe4 --- /dev/null +++ b/vendor/go.opentelemetry.io/otel/semconv/v1.17.0/httpconv/http.go @@ -0,0 +1,152 @@ +// Copyright The OpenTelemetry Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package httpconv provides OpenTelemetry HTTP semantic conventions for +// tracing telemetry. 
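+//
+// A typical client-side use is a sketch like the following, where span is a
+// started trace.Span and req/resp are the *http.Request and *http.Response
+// involved:
+//
+//	span.SetAttributes(httpconv.ClientRequest(req)...)
+//	// ... round trip ...
+//	span.SetAttributes(httpconv.ClientResponse(resp)...)
+//	span.SetStatus(httpconv.ClientStatus(resp.StatusCode))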
+package httpconv // import "go.opentelemetry.io/otel/semconv/v1.17.0/httpconv"
+
+import (
+	"net/http"
+
+	"go.opentelemetry.io/otel/attribute"
+	"go.opentelemetry.io/otel/codes"
+	"go.opentelemetry.io/otel/semconv/internal/v2"
+	semconv "go.opentelemetry.io/otel/semconv/v1.17.0"
+)
+
+var (
+	nc = &internal.NetConv{
+		NetHostNameKey:     semconv.NetHostNameKey,
+		NetHostPortKey:     semconv.NetHostPortKey,
+		NetPeerNameKey:     semconv.NetPeerNameKey,
+		NetPeerPortKey:     semconv.NetPeerPortKey,
+		NetSockPeerAddrKey: semconv.NetSockPeerAddrKey,
+		NetSockPeerPortKey: semconv.NetSockPeerPortKey,
+		NetTransportOther:  semconv.NetTransportOther,
+		NetTransportTCP:    semconv.NetTransportTCP,
+		NetTransportUDP:    semconv.NetTransportUDP,
+		NetTransportInProc: semconv.NetTransportInProc,
+	}
+
+	hc = &internal.HTTPConv{
+		NetConv: nc,
+
+		EnduserIDKey:                 semconv.EnduserIDKey,
+		HTTPClientIPKey:              semconv.HTTPClientIPKey,
+		HTTPFlavorKey:                semconv.HTTPFlavorKey,
+		HTTPMethodKey:                semconv.HTTPMethodKey,
+		HTTPRequestContentLengthKey:  semconv.HTTPRequestContentLengthKey,
+		HTTPResponseContentLengthKey: semconv.HTTPResponseContentLengthKey,
+		HTTPRouteKey:                 semconv.HTTPRouteKey,
+		HTTPSchemeHTTP:               semconv.HTTPSchemeHTTP,
+		HTTPSchemeHTTPS:              semconv.HTTPSchemeHTTPS,
+		HTTPStatusCodeKey:            semconv.HTTPStatusCodeKey,
+		HTTPTargetKey:                semconv.HTTPTargetKey,
+		HTTPURLKey:                   semconv.HTTPURLKey,
+		HTTPUserAgentKey:             semconv.HTTPUserAgentKey,
+	}
+)
+
+// ClientResponse returns trace attributes for an HTTP response received by a
+// client from a server. It will return the following attributes if the related
+// values are defined in resp: "http.status_code",
+// "http.response_content_length".
+//
+// This does not add all OpenTelemetry required attributes for an HTTP event;
+// it assumes ClientRequest was used to create the span with a complete set of
+// attributes. A complete set of attributes can be generated using the
+// request contained in resp. For example:
+//
+//	append(ClientResponse(resp), ClientRequest(resp.Request)...)
+func ClientResponse(resp *http.Response) []attribute.KeyValue {
+	return hc.ClientResponse(resp)
+}
+
+// ClientRequest returns trace attributes for an HTTP request made by a client.
+// The following attributes are always returned: "http.url", "http.flavor",
+// "http.method", "net.peer.name". The following attributes are returned if the
+// related values are defined in req: "net.peer.port", "http.user_agent",
+// "http.request_content_length", "enduser.id".
+func ClientRequest(req *http.Request) []attribute.KeyValue {
+	return hc.ClientRequest(req)
+}
+
+// ClientStatus returns a span status code and message for an HTTP status code
+// value received by a client.
+func ClientStatus(code int) (codes.Code, string) {
+	return hc.ClientStatus(code)
+}
+
+// ServerRequest returns trace attributes for an HTTP request received by a
+// server.
+//
+// The server must be the primary server name if it is known. For example, this
+// would be the ServerName directive
+// (https://httpd.apache.org/docs/2.4/mod/core.html#servername) for an Apache
+// server, and the server_name directive
+// (http://nginx.org/en/docs/http/ngx_http_core_module.html#server_name) for an
+// nginx server. More generically, the primary server name would be the host
+// header value that matches the default virtual host of an HTTP server. It
+// should include the host identifier and, if a port is used to route to the
+// server, that port identifier should be included as an appropriate port
+// suffix.
+//
+// If the primary server name is not known, server should be an empty string.
+// The req Host will be used to determine the server instead.
+//
+// The following attributes are always returned: "http.method", "http.scheme",
+// "http.flavor", "http.target", "net.host.name". The following attributes are
+// returned if the related values are defined in req: "net.host.port",
+// "net.sock.peer.addr", "net.sock.peer.port", "http.user_agent", "enduser.id",
+// "http.client_ip".
+func ServerRequest(server string, req *http.Request) []attribute.KeyValue {
+	return hc.ServerRequest(server, req)
+}
+
+// ServerStatus returns a span status code and message for an HTTP status code
+// value returned by a server. Status codes in the 400-499 range are not
+// returned as errors.
+func ServerStatus(code int) (codes.Code, string) {
+	return hc.ServerStatus(code)
+}
+
+// RequestHeader returns the contents of h as attributes.
+//
+// Instrumentation should require an explicit configuration of which headers to
+// capture and then prune what they pass here. Including all headers can be a
+// security risk - explicit configuration helps avoid leaking sensitive
+// information.
+//
+// The User-Agent header is already captured in the http.user_agent attribute
+// from ClientRequest and ServerRequest. Instrumentation may provide an option
+// to capture that header here even though it is not recommended. Otherwise,
+// instrumentation should filter that out of what is passed.
+func RequestHeader(h http.Header) []attribute.KeyValue {
+	return hc.RequestHeader(h)
+}
+
+// ResponseHeader returns the contents of h as attributes.
+//
+// Instrumentation should require an explicit configuration of which headers to
+// capture and then prune what they pass here. Including all headers can be a
+// security risk - explicit configuration helps avoid leaking sensitive
+// information.
+//
+// The User-Agent header is already captured in the http.user_agent attribute
+// from ClientRequest and ServerRequest. Instrumentation may provide an option
+// to capture that header here even though it is not recommended. Otherwise,
+// instrumentation should filter that out of what is passed.
+func ResponseHeader(h http.Header) []attribute.KeyValue {
+	return hc.ResponseHeader(h)
+}
diff --git a/vendor/go.uber.org/multierr/CHANGELOG.md b/vendor/go.uber.org/multierr/CHANGELOG.md
new file mode 100644
index 0000000000..d2c8aadaf0
--- /dev/null
+++ b/vendor/go.uber.org/multierr/CHANGELOG.md
@@ -0,0 +1,80 @@
+Releases
+========
+
+v1.9.0 (2022-12-12)
+===================
+
+- Add `AppendFunc`, which allows passing functions, similar to
+  `AppendInvoke`.
+
+- Bump up yaml.v3 dependency to 3.0.1.
+
+v1.8.0 (2022-02-28)
+===================
+
+- `Combine`: perform zero allocations when there are no errors.
+
+
+v1.7.0 (2021-05-06)
+===================
+
+- Add `AppendInvoke` to append into errors from `defer` blocks.
+
+
+v1.6.0 (2020-09-14)
+===================
+
+- Actually drop library dependency on development-time tooling.
+
+
+v1.5.0 (2020-02-24)
+===================
+
+- Drop library dependency on development-time tooling.
+
+
+v1.4.0 (2019-11-04)
+===================
+
+- Add `AppendInto` function to more ergonomically build errors inside a
+  loop.
+
+
+v1.3.0 (2019-10-29)
+===================
+
+- Switch to Go modules.
+
+
+v1.2.0 (2019-09-26)
+===================
+
+- Support extracting and matching against wrapped errors with `errors.As`
+  and `errors.Is`.
+
+
+v1.1.0 (2017-06-30)
+===================
+
+- Added an `Errors(error) []error` function to extract the underlying list of
+  errors for a multierr error.
+
+
+v1.0.0 (2017-05-31)
+===================
+
+No changes since v0.2.0. This release is committing to making no breaking
+changes to the current API in the 1.X series.
+
+
+v0.2.0 (2017-04-11)
+===================
+
+- Repeatedly appending to the same error is now faster due to fewer
+  allocations.
+
+
+v0.1.0 (2017-03-31)
+===================
+
+- Initial release
diff --git a/vendor/go.uber.org/multierr/README.md b/vendor/go.uber.org/multierr/README.md
new file mode 100644
index 0000000000..70aacecd71
--- /dev/null
+++ b/vendor/go.uber.org/multierr/README.md
@@ -0,0 +1,23 @@
+# multierr [![GoDoc][doc-img]][doc] [![Build Status][ci-img]][ci] [![Coverage Status][cov-img]][cov]
+
+`multierr` allows combining one or more Go `error`s together.
+
+## Installation
+
+    go get -u go.uber.org/multierr
+
+## Status
+
+Stable: No breaking changes will be made before 2.0.
+
+-------------------------------------------------------------------------------
+
+Released under the [MIT License].
+
+[MIT License]: LICENSE.txt
+[doc-img]: https://pkg.go.dev/badge/go.uber.org/multierr
+[doc]: https://pkg.go.dev/go.uber.org/multierr
+[ci-img]: https://github.com/uber-go/multierr/actions/workflows/go.yml/badge.svg
+[cov-img]: https://codecov.io/gh/uber-go/multierr/branch/master/graph/badge.svg
+[ci]: https://github.com/uber-go/multierr/actions/workflows/go.yml
+[cov]: https://codecov.io/gh/uber-go/multierr
diff --git a/vendor/go.uber.org/multierr/error.go b/vendor/go.uber.org/multierr/error.go
new file mode 100644
index 0000000000..cdd91ae56d
--- /dev/null
+++ b/vendor/go.uber.org/multierr/error.go
@@ -0,0 +1,681 @@
+// Copyright (c) 2017-2021 Uber Technologies, Inc.
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in
+// all copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+// THE SOFTWARE.
+
+// Package multierr allows combining one or more errors together.
+//
+// # Overview
+//
+// Errors can be combined with the use of the Combine function.
+//
+//	multierr.Combine(
+//		reader.Close(),
+//		writer.Close(),
+//		conn.Close(),
+//	)
+//
+// If only two errors are being combined, the Append function may be used
+// instead.
+//
+//	err = multierr.Append(reader.Close(), writer.Close())
+//
+// The underlying list of errors for a returned error object may be retrieved
+// with the Errors function.
+// +// errors := multierr.Errors(err) +// if len(errors) > 0 { +// fmt.Println("The following errors occurred:", errors) +// } +// +// # Appending from a loop +// +// You sometimes need to append into an error from a loop. +// +// var err error +// for _, item := range items { +// err = multierr.Append(err, process(item)) +// } +// +// Cases like this may require knowledge of whether an individual instance +// failed. This usually requires introduction of a new variable. +// +// var err error +// for _, item := range items { +// if perr := process(item); perr != nil { +// log.Warn("skipping item", item) +// err = multierr.Append(err, perr) +// } +// } +// +// multierr includes AppendInto to simplify cases like this. +// +// var err error +// for _, item := range items { +// if multierr.AppendInto(&err, process(item)) { +// log.Warn("skipping item", item) +// } +// } +// +// This will append the error into the err variable, and return true if that +// individual error was non-nil. +// +// See [AppendInto] for more information. +// +// # Deferred Functions +// +// Go makes it possible to modify the return value of a function in a defer +// block if the function was using named returns. This makes it possible to +// record resource cleanup failures from deferred blocks. +// +// func sendRequest(req Request) (err error) { +// conn, err := openConnection() +// if err != nil { +// return err +// } +// defer func() { +// err = multierr.Append(err, conn.Close()) +// }() +// // ... +// } +// +// multierr provides the Invoker type and AppendInvoke function to make cases +// like the above simpler and obviate the need for a closure. The following is +// roughly equivalent to the example above. +// +// func sendRequest(req Request) (err error) { +// conn, err := openConnection() +// if err != nil { +// return err +// } +// defer multierr.AppendInvoke(&err, multierr.Close(conn)) +// // ... +// } +// +// See [AppendInvoke] and [Invoker] for more information. +// +// NOTE: If you're modifying an error from inside a defer, you MUST use a named +// return value for that function. +// +// # Advanced Usage +// +// Errors returned by Combine and Append MAY implement the following +// interface. +// +// type errorGroup interface { +// // Returns a slice containing the underlying list of errors. +// // +// // This slice MUST NOT be modified by the caller. +// Errors() []error +// } +// +// Note that if you need access to list of errors behind a multierr error, you +// should prefer using the Errors function. That said, if you need cheap +// read-only access to the underlying errors slice, you can attempt to cast +// the error to this interface. You MUST handle the failure case gracefully +// because errors returned by Combine and Append are not guaranteed to +// implement this interface. +// +// var errors []error +// group, ok := err.(errorGroup) +// if ok { +// errors = group.Errors() +// } else { +// errors = []error{err} +// } +package multierr // import "go.uber.org/multierr" + +import ( + "bytes" + "errors" + "fmt" + "io" + "strings" + "sync" + + "go.uber.org/atomic" +) + +var ( + // Separator for single-line error messages. + _singlelineSeparator = []byte("; ") + + // Prefix for multi-line messages + _multilinePrefix = []byte("the following errors occurred:") + + // Prefix for the first and following lines of an item in a list of + // multi-line error messages. 
+ // + // For example, if a single item is: + // + // foo + // bar + // + // It will become, + // + // - foo + // bar + _multilineSeparator = []byte("\n - ") + _multilineIndent = []byte(" ") +) + +// _bufferPool is a pool of bytes.Buffers. +var _bufferPool = sync.Pool{ + New: func() interface{} { + return &bytes.Buffer{} + }, +} + +type errorGroup interface { + Errors() []error +} + +// Errors returns a slice containing zero or more errors that the supplied +// error is composed of. If the error is nil, a nil slice is returned. +// +// err := multierr.Append(r.Close(), w.Close()) +// errors := multierr.Errors(err) +// +// If the error is not composed of other errors, the returned slice contains +// just the error that was passed in. +// +// Callers of this function are free to modify the returned slice. +func Errors(err error) []error { + if err == nil { + return nil + } + + // Note that we're casting to multiError, not errorGroup. Our contract is + // that returned errors MAY implement errorGroup. Errors, however, only + // has special behavior for multierr-specific error objects. + // + // This behavior can be expanded in the future but I think it's prudent to + // start with as little as possible in terms of contract and possibility + // of misuse. + eg, ok := err.(*multiError) + if !ok { + return []error{err} + } + + return append(([]error)(nil), eg.Errors()...) +} + +// multiError is an error that holds one or more errors. +// +// An instance of this is guaranteed to be non-empty and flattened. That is, +// none of the errors inside multiError are other multiErrors. +// +// multiError formats to a semi-colon delimited list of error messages with +// %v and with a more readable multi-line format with %+v. +type multiError struct { + copyNeeded atomic.Bool + errors []error +} + +var _ errorGroup = (*multiError)(nil) + +// Errors returns the list of underlying errors. +// +// This slice MUST NOT be modified. +func (merr *multiError) Errors() []error { + if merr == nil { + return nil + } + return merr.errors +} + +// As attempts to find the first error in the error list that matches the type +// of the value that target points to. +// +// This function allows errors.As to traverse the values stored on the +// multierr error. +func (merr *multiError) As(target interface{}) bool { + for _, err := range merr.Errors() { + if errors.As(err, target) { + return true + } + } + return false +} + +// Is attempts to match the provided error against errors in the error list. +// +// This function allows errors.Is to traverse the values stored on the +// multierr error. 
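+//
+// As an illustrative sketch (errFoo and errBar stand in for hypothetical
+// sentinel errors; they are not part of this package):
+//
+//	err := multierr.Combine(errFoo, errBar)
+//	errors.Is(err, errBar) // true, matched via this method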
+func (merr *multiError) Is(target error) bool { + for _, err := range merr.Errors() { + if errors.Is(err, target) { + return true + } + } + return false +} + +func (merr *multiError) Error() string { + if merr == nil { + return "" + } + + buff := _bufferPool.Get().(*bytes.Buffer) + buff.Reset() + + merr.writeSingleline(buff) + + result := buff.String() + _bufferPool.Put(buff) + return result +} + +func (merr *multiError) Format(f fmt.State, c rune) { + if c == 'v' && f.Flag('+') { + merr.writeMultiline(f) + } else { + merr.writeSingleline(f) + } +} + +func (merr *multiError) writeSingleline(w io.Writer) { + first := true + for _, item := range merr.errors { + if first { + first = false + } else { + w.Write(_singlelineSeparator) + } + io.WriteString(w, item.Error()) + } +} + +func (merr *multiError) writeMultiline(w io.Writer) { + w.Write(_multilinePrefix) + for _, item := range merr.errors { + w.Write(_multilineSeparator) + writePrefixLine(w, _multilineIndent, fmt.Sprintf("%+v", item)) + } +} + +// Writes s to the writer with the given prefix added before each line after +// the first. +func writePrefixLine(w io.Writer, prefix []byte, s string) { + first := true + for len(s) > 0 { + if first { + first = false + } else { + w.Write(prefix) + } + + idx := strings.IndexByte(s, '\n') + if idx < 0 { + idx = len(s) - 1 + } + + io.WriteString(w, s[:idx+1]) + s = s[idx+1:] + } +} + +type inspectResult struct { + // Number of top-level non-nil errors + Count int + + // Total number of errors including multiErrors + Capacity int + + // Index of the first non-nil error in the list. Value is meaningless if + // Count is zero. + FirstErrorIdx int + + // Whether the list contains at least one multiError + ContainsMultiError bool +} + +// Inspects the given slice of errors so that we can efficiently allocate +// space for it. +func inspect(errors []error) (res inspectResult) { + first := true + for i, err := range errors { + if err == nil { + continue + } + + res.Count++ + if first { + first = false + res.FirstErrorIdx = i + } + + if merr, ok := err.(*multiError); ok { + res.Capacity += len(merr.errors) + res.ContainsMultiError = true + } else { + res.Capacity++ + } + } + return +} + +// fromSlice converts the given list of errors into a single error. +func fromSlice(errors []error) error { + // Don't pay to inspect small slices. + switch len(errors) { + case 0: + return nil + case 1: + return errors[0] + } + + res := inspect(errors) + switch res.Count { + case 0: + return nil + case 1: + // only one non-nil entry + return errors[res.FirstErrorIdx] + case len(errors): + if !res.ContainsMultiError { + // Error list is flat. Make a copy of it + // Otherwise "errors" escapes to the heap + // unconditionally for all other cases. + // This lets us optimize for the "no errors" case. + out := append(([]error)(nil), errors...) + return &multiError{errors: out} + } + } + + nonNilErrs := make([]error, 0, res.Capacity) + for _, err := range errors[res.FirstErrorIdx:] { + if err == nil { + continue + } + + if nested, ok := err.(*multiError); ok { + nonNilErrs = append(nonNilErrs, nested.errors...) + } else { + nonNilErrs = append(nonNilErrs, err) + } + } + + return &multiError{errors: nonNilErrs} +} + +// Combine combines the passed errors into a single error. +// +// If zero arguments were passed or if all items are nil, a nil error is +// returned. +// +// Combine(nil, nil) // == nil +// +// If only a single error was passed, it is returned as-is. 
+// +// Combine(err) // == err +// +// Combine skips over nil arguments so this function may be used to combine +// together errors from operations that fail independently of each other. +// +// multierr.Combine( +// reader.Close(), +// writer.Close(), +// pipe.Close(), +// ) +// +// If any of the passed errors is a multierr error, it will be flattened along +// with the other errors. +// +// multierr.Combine(multierr.Combine(err1, err2), err3) +// // is the same as +// multierr.Combine(err1, err2, err3) +// +// The returned error formats into a readable multi-line error message if +// formatted with %+v. +// +// fmt.Sprintf("%+v", multierr.Combine(err1, err2)) +func Combine(errors ...error) error { + return fromSlice(errors) +} + +// Append appends the given errors together. Either value may be nil. +// +// This function is a specialization of Combine for the common case where +// there are only two errors. +// +// err = multierr.Append(reader.Close(), writer.Close()) +// +// The following pattern may also be used to record failure of deferred +// operations without losing information about the original error. +// +// func doSomething(..) (err error) { +// f := acquireResource() +// defer func() { +// err = multierr.Append(err, f.Close()) +// }() +// +// Note that the variable MUST be a named return to append an error to it from +// the defer statement. See also [AppendInvoke]. +func Append(left error, right error) error { + switch { + case left == nil: + return right + case right == nil: + return left + } + + if _, ok := right.(*multiError); !ok { + if l, ok := left.(*multiError); ok && !l.copyNeeded.Swap(true) { + // Common case where the error on the left is constantly being + // appended to. + errs := append(l.errors, right) + return &multiError{errors: errs} + } else if !ok { + // Both errors are single errors. + return &multiError{errors: []error{left, right}} + } + } + + // Either right or both, left and right, are multiErrors. Rely on usual + // expensive logic. + errors := [2]error{left, right} + return fromSlice(errors[0:]) +} + +// AppendInto appends an error into the destination of an error pointer and +// returns whether the error being appended was non-nil. +// +// var err error +// multierr.AppendInto(&err, r.Close()) +// multierr.AppendInto(&err, w.Close()) +// +// The above is equivalent to, +// +// err := multierr.Append(r.Close(), w.Close()) +// +// As AppendInto reports whether the provided error was non-nil, it may be +// used to build a multierr error in a loop more ergonomically. For example: +// +// var err error +// for line := range lines { +// var item Item +// if multierr.AppendInto(&err, parse(line, &item)) { +// continue +// } +// items = append(items, item) +// } +// +// Compare this with a version that relies solely on Append: +// +// var err error +// for line := range lines { +// var item Item +// if parseErr := parse(line, &item); parseErr != nil { +// err = multierr.Append(err, parseErr) +// continue +// } +// items = append(items, item) +// } +func AppendInto(into *error, err error) (errored bool) { + if into == nil { + // We panic if 'into' is nil. This is not documented above + // because suggesting that the pointer must be non-nil may + // confuse users into thinking that the error that it points + // to must be non-nil. + panic("misuse of multierr.AppendInto: into pointer must not be nil") + } + + if err == nil { + return false + } + *into = Append(*into, err) + return true +} + +// Invoker is an operation that may fail with an error. 
Use it with
+// AppendInvoke to append the result of calling the function into an error.
+// This allows you to conveniently defer capture of failing operations.
+//
+// See also [Close] and [Invoke].
+type Invoker interface {
+	Invoke() error
+}
+
+// Invoke wraps a function which may fail with an error to match the Invoker
+// interface. Use it to supply functions matching this signature to
+// AppendInvoke.
+//
+// For example,
+//
+//	func processReader(r io.Reader) (err error) {
+//		scanner := bufio.NewScanner(r)
+//		defer multierr.AppendInvoke(&err, multierr.Invoke(scanner.Err))
+//		for scanner.Scan() {
+//			// ...
+//		}
+//		// ...
+//	}
+//
+// In this example, the following line will construct the Invoker right away,
+// but defer the invocation of scanner.Err() until the function returns.
+//
+//	defer multierr.AppendInvoke(&err, multierr.Invoke(scanner.Err))
+//
+// Note that the error you're appending to from the defer statement MUST be a
+// named return.
+type Invoke func() error
+
+// Invoke calls the supplied function and returns its result.
+func (i Invoke) Invoke() error { return i() }
+
+// Close builds an Invoker that closes the provided io.Closer. Use it with
+// AppendInvoke to close io.Closers and append their results into an error.
+//
+// For example,
+//
+//	func processFile(path string) (err error) {
+//		f, err := os.Open(path)
+//		if err != nil {
+//			return err
+//		}
+//		defer multierr.AppendInvoke(&err, multierr.Close(f))
+//		return processReader(f)
+//	}
+//
+// In this example, multierr.Close will construct the Invoker right away, but
+// defer the invocation of f.Close until the function returns.
+//
+//	defer multierr.AppendInvoke(&err, multierr.Close(f))
+//
+// Note that the error you're appending to from the defer statement MUST be a
+// named return.
+func Close(closer io.Closer) Invoker {
+	return Invoke(closer.Close)
+}
+
+// AppendInvoke appends the result of calling the given Invoker into the
+// provided error pointer. Use it with named returns to safely defer
+// invocation of fallible operations until a function returns, and capture the
+// resulting errors.
+//
+//	func doSomething(...) (err error) {
+//		// ...
+//		f, err := openFile(..)
+//		if err != nil {
+//			return err
+//		}
+//
+//		// multierr will call f.Close() when this function returns and
+//		// if the operation fails, it appends its error into the
+//		// returned error.
+//		defer multierr.AppendInvoke(&err, multierr.Close(f))
+//
+//		scanner := bufio.NewScanner(f)
+//		// Similarly, this schedules scanner.Err to be called and
+//		// inspected when the function returns, appending its error
+//		// into the returned error.
+//		defer multierr.AppendInvoke(&err, multierr.Invoke(scanner.Err))
+//
+//		// ...
+//	}
+//
+// NOTE: If used with a defer, the error variable MUST be a named return.
+//
+// Without defer, AppendInvoke behaves exactly like AppendInto.
+//
+//	err := // ...
+//	multierr.AppendInvoke(&err, multierr.Invoke(foo))
+//
+//	// ...is roughly equivalent to...
+//
+//	err := // ...
+//	multierr.AppendInto(&err, foo())
+//
+// The advantage of the indirection introduced by Invoker is to make it easy
+// to defer the invocation of a function. Without this indirection, the
+// invoked function will be evaluated at the time of the defer block rather
+// than when the function returns.
+//
+//	// BAD: This is likely not what the caller intended. This will evaluate
+//	// foo() right away and append its result into the error when the
+//	// function returns.
+//	defer multierr.AppendInto(&err, foo())
+//
+//	// GOOD: This will defer invocation of foo until the function returns.
+//	defer multierr.AppendInvoke(&err, multierr.Invoke(foo))
+//
+// multierr provides a few Invoker implementations out of the box for
+// convenience. See [Invoker] for more information.
+func AppendInvoke(into *error, invoker Invoker) {
+	AppendInto(into, invoker.Invoke())
+}
+
+// AppendFunc is a shorthand for [AppendInvoke].
+// It allows using a function or method value directly
+// without having to wrap it into an [Invoker] interface.
+//
+//	func doSomething(...) (err error) {
+//		w, err := startWorker(...)
+//		if err != nil {
+//			return err
+//		}
+//
+//		// multierr will call w.Stop() when this function returns and
+//		// if the operation fails, it appends its error into the
+//		// returned error.
+//		defer multierr.AppendFunc(&err, w.Stop)
+//	}
+func AppendFunc(into *error, fn func() error) {
+	AppendInvoke(into, Invoke(fn))
+}
diff --git a/vendor/go.uber.org/multierr/glide.yaml b/vendor/go.uber.org/multierr/glide.yaml
new file mode 100644
index 0000000000..6ef084ec24
--- /dev/null
+++ b/vendor/go.uber.org/multierr/glide.yaml
@@ -0,0 +1,8 @@
+package: go.uber.org/multierr
+import:
+- package: go.uber.org/atomic
+  version: ^1
+testImport:
+- package: github.com/stretchr/testify
+  subpackages:
+  - assert
diff --git a/vendor/golang.org/x/arch/x86/x86asm/gnu.go b/vendor/golang.org/x/arch/x86/x86asm/gnu.go
new file mode 100644
index 0000000000..75cff72b03
--- /dev/null
+++ b/vendor/golang.org/x/arch/x86/x86asm/gnu.go
@@ -0,0 +1,956 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package x86asm
+
+import (
+	"fmt"
+	"strings"
+)
+
+// GNUSyntax returns the GNU assembler syntax for the instruction, as defined by GNU binutils.
+// This general form is often called ``AT&T syntax'' as a reference to AT&T System V Unix.
+func GNUSyntax(inst Inst, pc uint64, symname SymLookup) string {
+	// Rewrite instruction to mimic GNU peculiarities.
+	// Note that inst has been passed by value and contains
+	// no pointers, so any changes we make here are local
+	// and will not propagate back out to the caller.
+
+	if symname == nil {
+		symname = func(uint64) (string, uint64) { return "", 0 }
+	}
+
+	// Adjust opcode [sic].
+	switch inst.Op {
+	case FDIV, FDIVR, FSUB, FSUBR, FDIVP, FDIVRP, FSUBP, FSUBRP:
+		// DC E0, DC F0: libopcodes swaps FSUBR/FSUB and FDIVR/FDIV, at least
+		// if you believe the Intel manual is correct (the encoding is irregular as given;
+		// libopcodes uses the more regular expected encoding).
+		// TODO(rsc): Test to ensure Intel manuals are correct and report to libopcodes maintainers?
+		// NOTE: iant thinks this is deliberate, but we can't find the history.
+		_, reg1 := inst.Args[0].(Reg)
+		_, reg2 := inst.Args[1].(Reg)
+		if reg1 && reg2 && (inst.Opcode>>24 == 0xDC || inst.Opcode>>24 == 0xDE) {
+			switch inst.Op {
+			case FDIV:
+				inst.Op = FDIVR
+			case FDIVR:
+				inst.Op = FDIV
+			case FSUB:
+				inst.Op = FSUBR
+			case FSUBR:
+				inst.Op = FSUB
+			case FDIVP:
+				inst.Op = FDIVRP
+			case FDIVRP:
+				inst.Op = FDIVP
+			case FSUBP:
+				inst.Op = FSUBRP
+			case FSUBRP:
+				inst.Op = FSUBP
+			}
+		}
+
+	case MOVNTSD:
+		// MOVNTSD is F2 0F 2B /r.
+		// MOVNTSS is F3 0F 2B /r (supposedly; not in manuals).
+		// Usually inner prefixes win for display,
+		// so that F3 F2 0F 2B 11 is REP MOVNTSD
+		// and F2 F3 0F 2B 11 is REPN MOVNTSS.
+ // Libopcodes always prefers MOVNTSS regardless of prefix order. + if countPrefix(&inst, 0xF3) > 0 { + found := false + for i := len(inst.Prefix) - 1; i >= 0; i-- { + switch inst.Prefix[i] & 0xFF { + case 0xF3: + if !found { + found = true + inst.Prefix[i] |= PrefixImplicit + } + case 0xF2: + inst.Prefix[i] &^= PrefixImplicit + } + } + inst.Op = MOVNTSS + } + } + + // Add implicit arguments. + switch inst.Op { + case MONITOR: + inst.Args[0] = EDX + inst.Args[1] = ECX + inst.Args[2] = EAX + if inst.AddrSize == 16 { + inst.Args[2] = AX + } + + case MWAIT: + if inst.Mode == 64 { + inst.Args[0] = RCX + inst.Args[1] = RAX + } else { + inst.Args[0] = ECX + inst.Args[1] = EAX + } + } + + // Adjust which prefixes will be displayed. + // The rule is to display all the prefixes not implied by + // the usual instruction display, that is, all the prefixes + // except the ones with PrefixImplicit set. + // However, of course, there are exceptions to the rule. + switch inst.Op { + case CRC32: + // CRC32 has a mandatory F2 prefix. + // If there are multiple F2s and no F3s, the extra F2s do not print. + // (And Decode has already marked them implicit.) + // However, if there is an F3 anywhere, then the extra F2s do print. + // If there are multiple F2 prefixes *and* an (ignored) F3, + // then libopcodes prints the extra F2s as REPNs. + if countPrefix(&inst, 0xF2) > 1 { + unmarkImplicit(&inst, 0xF2) + markLastImplicit(&inst, 0xF2) + } + + // An unused data size override should probably be shown, + // to distinguish DATA16 CRC32B from plain CRC32B, + // but libopcodes always treats the final override as implicit + // and the others as explicit. + unmarkImplicit(&inst, PrefixDataSize) + markLastImplicit(&inst, PrefixDataSize) + + case CVTSI2SD, CVTSI2SS: + if !isMem(inst.Args[1]) { + markLastImplicit(&inst, PrefixDataSize) + } + + case CVTSD2SI, CVTSS2SI, CVTTSD2SI, CVTTSS2SI, + ENTER, FLDENV, FNSAVE, FNSTENV, FRSTOR, LGDT, LIDT, LRET, + POP, PUSH, RET, SGDT, SIDT, SYSRET, XBEGIN: + markLastImplicit(&inst, PrefixDataSize) + + case LOOP, LOOPE, LOOPNE, MONITOR: + markLastImplicit(&inst, PrefixAddrSize) + + case MOV: + // The 16-bit and 32-bit forms of MOV Sreg, dst and MOV src, Sreg + // cannot be distinguished when src or dst refers to memory, because + // Sreg is always a 16-bit value, even when we're doing a 32-bit + // instruction. Because the instruction tables distinguished these two, + // any operand size prefix has been marked as used (to decide which + // branch to take). Unmark it, so that it will show up in disassembly, + // so that the reader can tell the size of memory operand. 
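+		// (Sketching the intent: with the prefix kept visible, output like
+		// "data16 mov %es,(%ebx)" stays distinguishable from the plain
+		// "mov %es,(%ebx)" form.)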
+		dst, _ := inst.Args[0].(Reg)
+		src, _ := inst.Args[1].(Reg)
+		if ES <= src && src <= GS && isMem(inst.Args[0]) || ES <= dst && dst <= GS && isMem(inst.Args[1]) {
+			unmarkImplicit(&inst, PrefixDataSize)
+		}
+
+	case MOVDQU:
+		if countPrefix(&inst, 0xF3) > 1 {
+			unmarkImplicit(&inst, 0xF3)
+			markLastImplicit(&inst, 0xF3)
+		}
+
+	case MOVQ2DQ:
+		markLastImplicit(&inst, PrefixDataSize)
+
+	case SLDT, SMSW, STR, FXRSTOR, XRSTOR, XSAVE, XSAVEOPT, CMPXCHG8B:
+		if isMem(inst.Args[0]) {
+			unmarkImplicit(&inst, PrefixDataSize)
+		}
+
+	case SYSEXIT:
+		unmarkImplicit(&inst, PrefixDataSize)
+	}
+
+	if isCondJmp[inst.Op] || isLoop[inst.Op] || inst.Op == JCXZ || inst.Op == JECXZ || inst.Op == JRCXZ {
+		if countPrefix(&inst, PrefixCS) > 0 && countPrefix(&inst, PrefixDS) > 0 {
+			for i, p := range inst.Prefix {
+				switch p & 0xFFF {
+				case PrefixPN, PrefixPT:
+					inst.Prefix[i] &= 0xF0FF // cut interpretation bits, producing original segment prefix
+				}
+			}
+		}
+	}
+
+	// XACQUIRE/XRELEASE adjustment.
+	if inst.Op == MOV {
+		// MOV into memory is a candidate for turning REP into XRELEASE.
+		// However, if the REP is followed by a REPN, that REPN blocks the
+		// conversion.
+		haveREPN := false
+		for i := len(inst.Prefix) - 1; i >= 0; i-- {
+			switch inst.Prefix[i] &^ PrefixIgnored {
+			case PrefixREPN:
+				haveREPN = true
+			case PrefixXRELEASE:
+				if haveREPN {
+					inst.Prefix[i] = PrefixREP
+				}
+			}
+		}
+	}
+
+	// We only format the final F2/F3 as XRELEASE/XACQUIRE.
+	haveXA := false
+	haveXR := false
+	for i := len(inst.Prefix) - 1; i >= 0; i-- {
+		switch inst.Prefix[i] &^ PrefixIgnored {
+		case PrefixXRELEASE:
+			if !haveXR {
+				haveXR = true
+			} else {
+				inst.Prefix[i] = PrefixREP
+			}
+
+		case PrefixXACQUIRE:
+			if !haveXA {
+				haveXA = true
+			} else {
+				inst.Prefix[i] = PrefixREPN
+			}
+		}
+	}
+
+	// Determine opcode.
+	op := strings.ToLower(inst.Op.String())
+	if alt := gnuOp[inst.Op]; alt != "" {
+		op = alt
+	}
+
+	// Determine opcode suffix.
+	// Libopcodes omits the suffix if the width of the operation
+	// can be inferred from a register argument. For example,
+	// add $1, %ebx has no suffix because you can tell from the
+	// 32-bit register destination that it is a 32-bit add,
+	// but in addl $1, (%ebx), the destination is memory, so the
+	// size is not evident without the l suffix.
+	needSuffix := true
+SuffixLoop:
+	for i, a := range inst.Args {
+		if a == nil {
+			break
+		}
+		switch a := a.(type) {
+		case Reg:
+			switch inst.Op {
+			case MOVSX, MOVZX:
+				continue
+
+			case SHL, SHR, RCL, RCR, ROL, ROR, SAR:
+				if i == 1 {
+					// shift count does not tell us operand size
+					continue
+				}
+
+			case CRC32:
+				// The source argument does tell us operand size,
+				// but libopcodes still always puts a suffix on crc32.
+				continue
+
+			case PUSH, POP:
+				// Even though segment registers are 16-bit, push and pop
+				// can save/restore them from 32-bit slots, so they
+				// do not imply operand size.
+				if ES <= a && a <= GS {
+					continue
+				}
+
+			case CVTSI2SD, CVTSI2SS:
+				// The integer register argument takes priority.
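+				// (An XMM register operand is skipped below, since it never
+				// pins the operand size on its own.)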
+ if X0 <= a && a <= X15 { + continue + } + } + + if AL <= a && a <= R15 || ES <= a && a <= GS || X0 <= a && a <= X15 || M0 <= a && a <= M7 { + needSuffix = false + break SuffixLoop + } + } + } + + if needSuffix { + switch inst.Op { + case CMPXCHG8B, FLDCW, FNSTCW, FNSTSW, LDMXCSR, LLDT, LMSW, LTR, PCLMULQDQ, + SETA, SETAE, SETB, SETBE, SETE, SETG, SETGE, SETL, SETLE, SETNE, SETNO, SETNP, SETNS, SETO, SETP, SETS, + SLDT, SMSW, STMXCSR, STR, VERR, VERW: + // For various reasons, libopcodes emits no suffix for these instructions. + + case CRC32: + op += byteSizeSuffix(argBytes(&inst, inst.Args[1])) + + case LGDT, LIDT, SGDT, SIDT: + op += byteSizeSuffix(inst.DataSize / 8) + + case MOVZX, MOVSX: + // Integer size conversions get two suffixes. + op = op[:4] + byteSizeSuffix(argBytes(&inst, inst.Args[1])) + byteSizeSuffix(argBytes(&inst, inst.Args[0])) + + case LOOP, LOOPE, LOOPNE: + // Add w suffix to indicate use of CX register instead of ECX. + if inst.AddrSize == 16 { + op += "w" + } + + case CALL, ENTER, JMP, LCALL, LEAVE, LJMP, LRET, RET, SYSRET, XBEGIN: + // Add w suffix to indicate use of 16-bit target. + // Exclude JMP rel8. + if inst.Opcode>>24 == 0xEB { + break + } + if inst.DataSize == 16 && inst.Mode != 16 { + markLastImplicit(&inst, PrefixDataSize) + op += "w" + } else if inst.Mode == 64 { + op += "q" + } + + case FRSTOR, FNSAVE, FNSTENV, FLDENV: + // Add s suffix to indicate shortened FPU state (I guess). + if inst.DataSize == 16 { + op += "s" + } + + case PUSH, POP: + if markLastImplicit(&inst, PrefixDataSize) { + op += byteSizeSuffix(inst.DataSize / 8) + } else if inst.Mode == 64 { + op += "q" + } else { + op += byteSizeSuffix(inst.MemBytes) + } + + default: + if isFloat(inst.Op) { + // I can't explain any of this, but it's what libopcodes does. + switch inst.MemBytes { + default: + if (inst.Op == FLD || inst.Op == FSTP) && isMem(inst.Args[0]) { + op += "t" + } + case 4: + if isFloatInt(inst.Op) { + op += "l" + } else { + op += "s" + } + case 8: + if isFloatInt(inst.Op) { + op += "ll" + } else { + op += "l" + } + } + break + } + + op += byteSizeSuffix(inst.MemBytes) + } + } + + // Adjust special case opcodes. + switch inst.Op { + case 0: + if inst.Prefix[0] != 0 { + return strings.ToLower(inst.Prefix[0].String()) + } + + case INT: + if inst.Opcode>>24 == 0xCC { + inst.Args[0] = nil + op = "int3" + } + + case CMPPS, CMPPD, CMPSD_XMM, CMPSS: + imm, ok := inst.Args[2].(Imm) + if ok && 0 <= imm && imm < 8 { + inst.Args[2] = nil + op = cmppsOps[imm] + op[3:] + } + + case PCLMULQDQ: + imm, ok := inst.Args[2].(Imm) + if ok && imm&^0x11 == 0 { + inst.Args[2] = nil + op = pclmulqOps[(imm&0x10)>>3|(imm&1)] + } + + case XLATB: + if markLastImplicit(&inst, PrefixAddrSize) { + op = "xlat" // not xlatb + } + } + + // Build list of argument strings. + var ( + usedPrefixes bool // segment prefixes consumed by Mem formatting + args []string // formatted arguments + ) + for i, a := range inst.Args { + if a == nil { + break + } + switch inst.Op { + case MOVSB, MOVSW, MOVSD, MOVSQ, OUTSB, OUTSW, OUTSD: + if i == 0 { + usedPrefixes = true // disable use of prefixes for first argument + } else { + usedPrefixes = false + } + } + if a == Imm(1) && (inst.Opcode>>24)&^1 == 0xD0 { + continue + } + args = append(args, gnuArg(&inst, pc, symname, a, &usedPrefixes)) + } + + // The default is to print the arguments in reverse Intel order. + // A few instructions inhibit this behavior. 
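+	// (For example, Intel "MOV EAX, [EBX]" comes out as "mov (%ebx),%eax",
+	// a sketch of the usual AT&T operand order.)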
+ switch inst.Op { + case BOUND, LCALL, ENTER, LJMP: + // no reverse + default: + // reverse args + for i, j := 0, len(args)-1; i < j; i, j = i+1, j-1 { + args[i], args[j] = args[j], args[i] + } + } + + // Build prefix string. + // Must be after argument formatting, which can turn off segment prefixes. + var ( + prefix = "" // output string + numAddr = 0 + numData = 0 + implicitData = false + ) + for _, p := range inst.Prefix { + if p&0xFF == PrefixDataSize && p&PrefixImplicit != 0 { + implicitData = true + } + } + for _, p := range inst.Prefix { + if p == 0 || p.IsVEX() { + break + } + if p&PrefixImplicit != 0 { + continue + } + switch p &^ (PrefixIgnored | PrefixInvalid) { + default: + if p.IsREX() { + if p&0xFF == PrefixREX { + prefix += "rex " + } else { + prefix += "rex." + p.String()[4:] + " " + } + break + } + prefix += strings.ToLower(p.String()) + " " + + case PrefixPN: + op += ",pn" + continue + + case PrefixPT: + op += ",pt" + continue + + case PrefixAddrSize, PrefixAddr16, PrefixAddr32: + // For unknown reasons, if the addr16 prefix is repeated, + // libopcodes displays all but the last as addr32, even though + // the addressing form used in a memory reference is clearly + // still 16-bit. + n := 32 + if inst.Mode == 32 { + n = 16 + } + numAddr++ + if countPrefix(&inst, PrefixAddrSize) > numAddr { + n = inst.Mode + } + prefix += fmt.Sprintf("addr%d ", n) + continue + + case PrefixData16, PrefixData32: + if implicitData && countPrefix(&inst, PrefixDataSize) > 1 { + // Similar to the addr32 logic above, but it only kicks in + // when something used the data size prefix (one is implicit). + n := 16 + if inst.Mode == 16 { + n = 32 + } + numData++ + if countPrefix(&inst, PrefixDataSize) > numData { + if inst.Mode == 16 { + n = 16 + } else { + n = 32 + } + } + prefix += fmt.Sprintf("data%d ", n) + continue + } + prefix += strings.ToLower(p.String()) + " " + } + } + + // Finally! Put it all together. + text := prefix + op + if args != nil { + text += " " + // Indirect call/jmp gets a star to distinguish from direct jump address. + if (inst.Op == CALL || inst.Op == JMP || inst.Op == LJMP || inst.Op == LCALL) && (isMem(inst.Args[0]) || isReg(inst.Args[0])) { + text += "*" + } + text += strings.Join(args, ",") + } + return text +} + +// gnuArg returns the GNU syntax for the argument x from the instruction inst. +// If *usedPrefixes is false and x is a Mem, then the formatting +// includes any segment prefixes and sets *usedPrefixes to true. +func gnuArg(inst *Inst, pc uint64, symname SymLookup, x Arg, usedPrefixes *bool) string { + if x == nil { + return "" + } + switch x := x.(type) { + case Reg: + switch inst.Op { + case CVTSI2SS, CVTSI2SD, CVTSS2SI, CVTSD2SI, CVTTSD2SI, CVTTSS2SI: + if inst.DataSize == 16 && EAX <= x && x <= R15L { + x -= EAX - AX + } + + case IN, INSB, INSW, INSD, OUT, OUTSB, OUTSW, OUTSD: + // DX is the port, but libopcodes prints it as if it were a memory reference. 
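+			// (e.g. "out %al,(%dx)" rather than "out %al,%dx"; a sketch.)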
+ if x == DX { + return "(%dx)" + } + case VMOVDQA, VMOVDQU, VMOVNTDQA, VMOVNTDQ: + return strings.Replace(gccRegName[x], "xmm", "ymm", -1) + } + return gccRegName[x] + case Mem: + if s, disp := memArgToSymbol(x, pc, inst.Len, symname); s != "" { + suffix := "" + if disp != 0 { + suffix = fmt.Sprintf("%+d", disp) + } + return fmt.Sprintf("%s%s", s, suffix) + } + seg := "" + var haveCS, haveDS, haveES, haveFS, haveGS, haveSS bool + switch x.Segment { + case CS: + haveCS = true + case DS: + haveDS = true + case ES: + haveES = true + case FS: + haveFS = true + case GS: + haveGS = true + case SS: + haveSS = true + } + switch inst.Op { + case INSB, INSW, INSD, STOSB, STOSW, STOSD, STOSQ, SCASB, SCASW, SCASD, SCASQ: + // These do not accept segment prefixes, at least in the GNU rendering. + default: + if *usedPrefixes { + break + } + for i := len(inst.Prefix) - 1; i >= 0; i-- { + p := inst.Prefix[i] &^ PrefixIgnored + if p == 0 { + continue + } + switch p { + case PrefixCS: + if !haveCS { + haveCS = true + inst.Prefix[i] |= PrefixImplicit + } + case PrefixDS: + if !haveDS { + haveDS = true + inst.Prefix[i] |= PrefixImplicit + } + case PrefixES: + if !haveES { + haveES = true + inst.Prefix[i] |= PrefixImplicit + } + case PrefixFS: + if !haveFS { + haveFS = true + inst.Prefix[i] |= PrefixImplicit + } + case PrefixGS: + if !haveGS { + haveGS = true + inst.Prefix[i] |= PrefixImplicit + } + case PrefixSS: + if !haveSS { + haveSS = true + inst.Prefix[i] |= PrefixImplicit + } + } + } + *usedPrefixes = true + } + if haveCS { + seg += "%cs:" + } + if haveDS { + seg += "%ds:" + } + if haveSS { + seg += "%ss:" + } + if haveES { + seg += "%es:" + } + if haveFS { + seg += "%fs:" + } + if haveGS { + seg += "%gs:" + } + disp := "" + if x.Disp != 0 { + disp = fmt.Sprintf("%#x", x.Disp) + } + if x.Scale == 0 || x.Index == 0 && x.Scale == 1 && (x.Base == ESP || x.Base == RSP || x.Base == 0 && inst.Mode == 64) { + if x.Base == 0 { + return seg + disp + } + return fmt.Sprintf("%s%s(%s)", seg, disp, gccRegName[x.Base]) + } + base := gccRegName[x.Base] + if x.Base == 0 { + base = "" + } + index := gccRegName[x.Index] + if x.Index == 0 { + if inst.AddrSize == 64 { + index = "%riz" + } else { + index = "%eiz" + } + } + if AX <= x.Base && x.Base <= DI { + // 16-bit addressing - no scale + return fmt.Sprintf("%s%s(%s,%s)", seg, disp, base, index) + } + return fmt.Sprintf("%s%s(%s,%s,%d)", seg, disp, base, index, x.Scale) + case Rel: + if pc == 0 { + return fmt.Sprintf(".%+#x", int64(x)) + } else { + addr := pc + uint64(inst.Len) + uint64(x) + if s, base := symname(addr); s != "" && addr == base { + return fmt.Sprintf("%s", s) + } else { + addr := pc + uint64(inst.Len) + uint64(x) + return fmt.Sprintf("%#x", addr) + } + } + case Imm: + if s, base := symname(uint64(x)); s != "" { + suffix := "" + if uint64(x) != base { + suffix = fmt.Sprintf("%+d", uint64(x)-base) + } + return fmt.Sprintf("$%s%s", s, suffix) + } + if inst.Mode == 32 { + return fmt.Sprintf("$%#x", uint32(x)) + } + return fmt.Sprintf("$%#x", int64(x)) + } + return x.String() +} + +var gccRegName = [...]string{ + 0: "REG0", + AL: "%al", + CL: "%cl", + BL: "%bl", + DL: "%dl", + AH: "%ah", + CH: "%ch", + BH: "%bh", + DH: "%dh", + SPB: "%spl", + BPB: "%bpl", + SIB: "%sil", + DIB: "%dil", + R8B: "%r8b", + R9B: "%r9b", + R10B: "%r10b", + R11B: "%r11b", + R12B: "%r12b", + R13B: "%r13b", + R14B: "%r14b", + R15B: "%r15b", + AX: "%ax", + CX: "%cx", + BX: "%bx", + DX: "%dx", + SP: "%sp", + BP: "%bp", + SI: "%si", + DI: "%di", + R8W: "%r8w", + R9W: "%r9w", + R10W: 
"%r10w", + R11W: "%r11w", + R12W: "%r12w", + R13W: "%r13w", + R14W: "%r14w", + R15W: "%r15w", + EAX: "%eax", + ECX: "%ecx", + EDX: "%edx", + EBX: "%ebx", + ESP: "%esp", + EBP: "%ebp", + ESI: "%esi", + EDI: "%edi", + R8L: "%r8d", + R9L: "%r9d", + R10L: "%r10d", + R11L: "%r11d", + R12L: "%r12d", + R13L: "%r13d", + R14L: "%r14d", + R15L: "%r15d", + RAX: "%rax", + RCX: "%rcx", + RDX: "%rdx", + RBX: "%rbx", + RSP: "%rsp", + RBP: "%rbp", + RSI: "%rsi", + RDI: "%rdi", + R8: "%r8", + R9: "%r9", + R10: "%r10", + R11: "%r11", + R12: "%r12", + R13: "%r13", + R14: "%r14", + R15: "%r15", + IP: "%ip", + EIP: "%eip", + RIP: "%rip", + F0: "%st", + F1: "%st(1)", + F2: "%st(2)", + F3: "%st(3)", + F4: "%st(4)", + F5: "%st(5)", + F6: "%st(6)", + F7: "%st(7)", + M0: "%mm0", + M1: "%mm1", + M2: "%mm2", + M3: "%mm3", + M4: "%mm4", + M5: "%mm5", + M6: "%mm6", + M7: "%mm7", + X0: "%xmm0", + X1: "%xmm1", + X2: "%xmm2", + X3: "%xmm3", + X4: "%xmm4", + X5: "%xmm5", + X6: "%xmm6", + X7: "%xmm7", + X8: "%xmm8", + X9: "%xmm9", + X10: "%xmm10", + X11: "%xmm11", + X12: "%xmm12", + X13: "%xmm13", + X14: "%xmm14", + X15: "%xmm15", + CS: "%cs", + SS: "%ss", + DS: "%ds", + ES: "%es", + FS: "%fs", + GS: "%gs", + GDTR: "%gdtr", + IDTR: "%idtr", + LDTR: "%ldtr", + MSW: "%msw", + TASK: "%task", + CR0: "%cr0", + CR1: "%cr1", + CR2: "%cr2", + CR3: "%cr3", + CR4: "%cr4", + CR5: "%cr5", + CR6: "%cr6", + CR7: "%cr7", + CR8: "%cr8", + CR9: "%cr9", + CR10: "%cr10", + CR11: "%cr11", + CR12: "%cr12", + CR13: "%cr13", + CR14: "%cr14", + CR15: "%cr15", + DR0: "%db0", + DR1: "%db1", + DR2: "%db2", + DR3: "%db3", + DR4: "%db4", + DR5: "%db5", + DR6: "%db6", + DR7: "%db7", + TR0: "%tr0", + TR1: "%tr1", + TR2: "%tr2", + TR3: "%tr3", + TR4: "%tr4", + TR5: "%tr5", + TR6: "%tr6", + TR7: "%tr7", +} + +var gnuOp = map[Op]string{ + CBW: "cbtw", + CDQ: "cltd", + CMPSD: "cmpsl", + CMPSD_XMM: "cmpsd", + CWD: "cwtd", + CWDE: "cwtl", + CQO: "cqto", + INSD: "insl", + IRET: "iretw", + IRETD: "iret", + IRETQ: "iretq", + LODSB: "lods", + LODSD: "lods", + LODSQ: "lods", + LODSW: "lods", + MOVSD: "movsl", + MOVSD_XMM: "movsd", + OUTSD: "outsl", + POPA: "popaw", + POPAD: "popa", + POPF: "popfw", + POPFD: "popf", + PUSHA: "pushaw", + PUSHAD: "pusha", + PUSHF: "pushfw", + PUSHFD: "pushf", + SCASB: "scas", + SCASD: "scas", + SCASQ: "scas", + SCASW: "scas", + STOSB: "stos", + STOSD: "stos", + STOSQ: "stos", + STOSW: "stos", + XLATB: "xlat", +} + +var cmppsOps = []string{ + "cmpeq", + "cmplt", + "cmple", + "cmpunord", + "cmpneq", + "cmpnlt", + "cmpnle", + "cmpord", +} + +var pclmulqOps = []string{ + "pclmullqlqdq", + "pclmulhqlqdq", + "pclmullqhqdq", + "pclmulhqhqdq", +} + +func countPrefix(inst *Inst, target Prefix) int { + n := 0 + for _, p := range inst.Prefix { + if p&0xFF == target&0xFF { + n++ + } + } + return n +} + +func markLastImplicit(inst *Inst, prefix Prefix) bool { + for i := len(inst.Prefix) - 1; i >= 0; i-- { + p := inst.Prefix[i] + if p&0xFF == prefix { + inst.Prefix[i] |= PrefixImplicit + return true + } + } + return false +} + +func unmarkImplicit(inst *Inst, prefix Prefix) { + for i := len(inst.Prefix) - 1; i >= 0; i-- { + p := inst.Prefix[i] + if p&0xFF == prefix { + inst.Prefix[i] &^= PrefixImplicit + } + } +} + +func byteSizeSuffix(b int) string { + switch b { + case 1: + return "b" + case 2: + return "w" + case 4: + return "l" + case 8: + return "q" + } + return "" +} + +func argBytes(inst *Inst, arg Arg) int { + if isMem(arg) { + return inst.MemBytes + } + return regBytes(arg) +} + +func isFloat(op Op) bool { + switch op { + case FADD, FCOM, 
FCOMP, FDIV, FDIVR, FIADD, FICOM, FICOMP, FIDIV, FIDIVR, FILD, FIMUL, FIST, FISTP, FISTTP, FISUB, FISUBR, FLD, FMUL, FST, FSTP, FSUB, FSUBR: + return true + } + return false +} + +func isFloatInt(op Op) bool { + switch op { + case FIADD, FICOM, FICOMP, FIDIV, FIDIVR, FILD, FIMUL, FIST, FISTP, FISTTP, FISUB, FISUBR: + return true + } + return false +} diff --git a/vendor/golang.org/x/arch/x86/x86asm/inst.go b/vendor/golang.org/x/arch/x86/x86asm/inst.go new file mode 100644 index 0000000000..4632b5064f --- /dev/null +++ b/vendor/golang.org/x/arch/x86/x86asm/inst.go @@ -0,0 +1,649 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package x86asm implements decoding of x86 machine code. +package x86asm + +import ( + "bytes" + "fmt" +) + +// An Inst is a single instruction. +type Inst struct { + Prefix Prefixes // Prefixes applied to the instruction. + Op Op // Opcode mnemonic + Opcode uint32 // Encoded opcode bits, left aligned (first byte is Opcode>>24, etc) + Args Args // Instruction arguments, in Intel order + Mode int // processor mode in bits: 16, 32, or 64 + AddrSize int // address size in bits: 16, 32, or 64 + DataSize int // operand size in bits: 16, 32, or 64 + MemBytes int // size of memory argument in bytes: 1, 2, 4, 8, 16, and so on. + Len int // length of encoded instruction in bytes + PCRel int // length of PC-relative address in instruction encoding + PCRelOff int // index of start of PC-relative address in instruction encoding +} + +// Prefixes is an array of prefixes associated with a single instruction. +// The prefixes are listed in the same order as found in the instruction: +// each prefix byte corresponds to one slot in the array. The first zero +// in the array marks the end of the prefixes. +type Prefixes [14]Prefix + +// A Prefix represents an Intel instruction prefix. +// The low 8 bits are the actual prefix byte encoding, +// and the top 8 bits contain distinguishing bits and metadata. +type Prefix uint16 + +const ( + // Metadata about the role of a prefix in an instruction. + PrefixImplicit Prefix = 0x8000 // prefix is implied by instruction text + PrefixIgnored Prefix = 0x4000 // prefix is ignored: either irrelevant or overridden by a later prefix + PrefixInvalid Prefix = 0x2000 // prefix makes entire instruction invalid (bad LOCK) + + // Memory segment overrides. + PrefixES Prefix = 0x26 // ES segment override + PrefixCS Prefix = 0x2E // CS segment override + PrefixSS Prefix = 0x36 // SS segment override + PrefixDS Prefix = 0x3E // DS segment override + PrefixFS Prefix = 0x64 // FS segment override + PrefixGS Prefix = 0x65 // GS segment override + + // Branch prediction. + PrefixPN Prefix = 0x12E // predict not taken (conditional branch only) + PrefixPT Prefix = 0x13E // predict taken (conditional branch only) + + // Size attributes. + PrefixDataSize Prefix = 0x66 // operand size override + PrefixData16 Prefix = 0x166 + PrefixData32 Prefix = 0x266 + PrefixAddrSize Prefix = 0x67 // address size override + PrefixAddr16 Prefix = 0x167 + PrefixAddr32 Prefix = 0x267 + + // One of a kind. + PrefixLOCK Prefix = 0xF0 // lock + PrefixREPN Prefix = 0xF2 // repeat not zero + PrefixXACQUIRE Prefix = 0x1F2 + PrefixBND Prefix = 0x2F2 + PrefixREP Prefix = 0xF3 // repeat + PrefixXRELEASE Prefix = 0x1F3 + + // The REX prefixes must be in the range [PrefixREX, PrefixREX+0x10). + // the other bits are set or not according to the intended use. 
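+	// For example, PrefixREX|PrefixREXW (0x48) is the REX.W byte, which
+	// String renders as "REX.W" (a sketch of how the bits compose).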
+ PrefixREX Prefix = 0x40 // REX 64-bit extension prefix + PrefixREXW Prefix = 0x08 // extension bit W (64-bit instruction width) + PrefixREXR Prefix = 0x04 // extension bit R (r field in modrm) + PrefixREXX Prefix = 0x02 // extension bit X (index field in sib) + PrefixREXB Prefix = 0x01 // extension bit B (r/m field in modrm or base field in sib) + PrefixVEX2Bytes Prefix = 0xC5 // Short form of vex prefix + PrefixVEX3Bytes Prefix = 0xC4 // Long form of vex prefix +) + +// IsREX reports whether p is a REX prefix byte. +func (p Prefix) IsREX() bool { + return p&0xF0 == PrefixREX +} + +func (p Prefix) IsVEX() bool { + return p&0xFF == PrefixVEX2Bytes || p&0xFF == PrefixVEX3Bytes +} + +func (p Prefix) String() string { + p &^= PrefixImplicit | PrefixIgnored | PrefixInvalid + if s := prefixNames[p]; s != "" { + return s + } + + if p.IsREX() { + s := "REX." + if p&PrefixREXW != 0 { + s += "W" + } + if p&PrefixREXR != 0 { + s += "R" + } + if p&PrefixREXX != 0 { + s += "X" + } + if p&PrefixREXB != 0 { + s += "B" + } + return s + } + + return fmt.Sprintf("Prefix(%#x)", int(p)) +} + +// An Op is an x86 opcode. +type Op uint32 + +func (op Op) String() string { + i := int(op) + if i < 0 || i >= len(opNames) || opNames[i] == "" { + return fmt.Sprintf("Op(%d)", i) + } + return opNames[i] +} + +// An Args holds the instruction arguments. +// If an instruction has fewer than 4 arguments, +// the final elements in the array are nil. +type Args [4]Arg + +// An Arg is a single instruction argument, +// one of these types: Reg, Mem, Imm, Rel. +type Arg interface { + String() string + isArg() +} + +// Note that the implements of Arg that follow are all sized +// so that on a 64-bit machine the data can be inlined in +// the interface value instead of requiring an allocation. + +// A Reg is a single register. +// The zero Reg value has no name but indicates ``no register.'' +type Reg uint8 + +const ( + _ Reg = iota + + // 8-bit + AL + CL + DL + BL + AH + CH + DH + BH + SPB + BPB + SIB + DIB + R8B + R9B + R10B + R11B + R12B + R13B + R14B + R15B + + // 16-bit + AX + CX + DX + BX + SP + BP + SI + DI + R8W + R9W + R10W + R11W + R12W + R13W + R14W + R15W + + // 32-bit + EAX + ECX + EDX + EBX + ESP + EBP + ESI + EDI + R8L + R9L + R10L + R11L + R12L + R13L + R14L + R15L + + // 64-bit + RAX + RCX + RDX + RBX + RSP + RBP + RSI + RDI + R8 + R9 + R10 + R11 + R12 + R13 + R14 + R15 + + // Instruction pointer. + IP // 16-bit + EIP // 32-bit + RIP // 64-bit + + // 387 floating point registers. + F0 + F1 + F2 + F3 + F4 + F5 + F6 + F7 + + // MMX registers. + M0 + M1 + M2 + M3 + M4 + M5 + M6 + M7 + + // XMM registers. + X0 + X1 + X2 + X3 + X4 + X5 + X6 + X7 + X8 + X9 + X10 + X11 + X12 + X13 + X14 + X15 + + // Segment registers. + ES + CS + SS + DS + FS + GS + + // System registers. + GDTR + IDTR + LDTR + MSW + TASK + + // Control registers. + CR0 + CR1 + CR2 + CR3 + CR4 + CR5 + CR6 + CR7 + CR8 + CR9 + CR10 + CR11 + CR12 + CR13 + CR14 + CR15 + + // Debug registers. + DR0 + DR1 + DR2 + DR3 + DR4 + DR5 + DR6 + DR7 + DR8 + DR9 + DR10 + DR11 + DR12 + DR13 + DR14 + DR15 + + // Task registers. + TR0 + TR1 + TR2 + TR3 + TR4 + TR5 + TR6 + TR7 +) + +const regMax = TR7 + +func (Reg) isArg() {} + +func (r Reg) String() string { + i := int(r) + if i < 0 || i >= len(regNames) || regNames[i] == "" { + return fmt.Sprintf("Reg(%d)", i) + } + return regNames[i] +} + +// A Mem is a memory reference. +// The general form is Segment:[Base+Scale*Index+Disp]. 
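+//
+// For example, Mem{Base: RAX, Scale: 4, Index: RCX, Disp: 16} prints via
+// String as "[RAX+4*RCX+0x10]" (a sketch derived from the String method below).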
+type Mem struct { + Segment Reg + Base Reg + Scale uint8 + Index Reg + Disp int64 +} + +func (Mem) isArg() {} + +func (m Mem) String() string { + var base, plus, scale, index, disp string + + if m.Base != 0 { + base = m.Base.String() + } + if m.Scale != 0 { + if m.Base != 0 { + plus = "+" + } + if m.Scale > 1 { + scale = fmt.Sprintf("%d*", m.Scale) + } + index = m.Index.String() + } + if m.Disp != 0 || m.Base == 0 && m.Scale == 0 { + disp = fmt.Sprintf("%+#x", m.Disp) + } + return "[" + base + plus + scale + index + disp + "]" +} + +// A Rel is an offset relative to the current instruction pointer. +type Rel int32 + +func (Rel) isArg() {} + +func (r Rel) String() string { + return fmt.Sprintf(".%+d", r) +} + +// An Imm is an integer constant. +type Imm int64 + +func (Imm) isArg() {} + +func (i Imm) String() string { + return fmt.Sprintf("%#x", int64(i)) +} + +func (i Inst) String() string { + var buf bytes.Buffer + for _, p := range i.Prefix { + if p == 0 { + break + } + if p&PrefixImplicit != 0 { + continue + } + fmt.Fprintf(&buf, "%v ", p) + } + fmt.Fprintf(&buf, "%v", i.Op) + sep := " " + for _, v := range i.Args { + if v == nil { + break + } + fmt.Fprintf(&buf, "%s%v", sep, v) + sep = ", " + } + return buf.String() +} + +func isReg(a Arg) bool { + _, ok := a.(Reg) + return ok +} + +func isSegReg(a Arg) bool { + r, ok := a.(Reg) + return ok && ES <= r && r <= GS +} + +func isMem(a Arg) bool { + _, ok := a.(Mem) + return ok +} + +func isImm(a Arg) bool { + _, ok := a.(Imm) + return ok +} + +func regBytes(a Arg) int { + r, ok := a.(Reg) + if !ok { + return 0 + } + if AL <= r && r <= R15B { + return 1 + } + if AX <= r && r <= R15W { + return 2 + } + if EAX <= r && r <= R15L { + return 4 + } + if RAX <= r && r <= R15 { + return 8 + } + return 0 +} + +func isSegment(p Prefix) bool { + switch p { + case PrefixCS, PrefixDS, PrefixES, PrefixFS, PrefixGS, PrefixSS: + return true + } + return false +} + +// The Op definitions and string list are in tables.go. 
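+//
+// As a usage sketch from outside the package (Decode lives in decode.go,
+// which is not part of this diff):
+//
+//	inst, err := x86asm.Decode([]byte{0x89, 0xd8}, 64) // mov %ebx,%eax
+//	if err == nil {
+//		fmt.Println(x86asm.GNUSyntax(inst, 0, nil))
+//	}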
+ +var prefixNames = map[Prefix]string{ + PrefixCS: "CS", + PrefixDS: "DS", + PrefixES: "ES", + PrefixFS: "FS", + PrefixGS: "GS", + PrefixSS: "SS", + PrefixLOCK: "LOCK", + PrefixREP: "REP", + PrefixREPN: "REPN", + PrefixAddrSize: "ADDRSIZE", + PrefixDataSize: "DATASIZE", + PrefixAddr16: "ADDR16", + PrefixData16: "DATA16", + PrefixAddr32: "ADDR32", + PrefixData32: "DATA32", + PrefixBND: "BND", + PrefixXACQUIRE: "XACQUIRE", + PrefixXRELEASE: "XRELEASE", + PrefixREX: "REX", + PrefixPT: "PT", + PrefixPN: "PN", +} + +var regNames = [...]string{ + AL: "AL", + CL: "CL", + BL: "BL", + DL: "DL", + AH: "AH", + CH: "CH", + BH: "BH", + DH: "DH", + SPB: "SPB", + BPB: "BPB", + SIB: "SIB", + DIB: "DIB", + R8B: "R8B", + R9B: "R9B", + R10B: "R10B", + R11B: "R11B", + R12B: "R12B", + R13B: "R13B", + R14B: "R14B", + R15B: "R15B", + AX: "AX", + CX: "CX", + BX: "BX", + DX: "DX", + SP: "SP", + BP: "BP", + SI: "SI", + DI: "DI", + R8W: "R8W", + R9W: "R9W", + R10W: "R10W", + R11W: "R11W", + R12W: "R12W", + R13W: "R13W", + R14W: "R14W", + R15W: "R15W", + EAX: "EAX", + ECX: "ECX", + EDX: "EDX", + EBX: "EBX", + ESP: "ESP", + EBP: "EBP", + ESI: "ESI", + EDI: "EDI", + R8L: "R8L", + R9L: "R9L", + R10L: "R10L", + R11L: "R11L", + R12L: "R12L", + R13L: "R13L", + R14L: "R14L", + R15L: "R15L", + RAX: "RAX", + RCX: "RCX", + RDX: "RDX", + RBX: "RBX", + RSP: "RSP", + RBP: "RBP", + RSI: "RSI", + RDI: "RDI", + R8: "R8", + R9: "R9", + R10: "R10", + R11: "R11", + R12: "R12", + R13: "R13", + R14: "R14", + R15: "R15", + IP: "IP", + EIP: "EIP", + RIP: "RIP", + F0: "F0", + F1: "F1", + F2: "F2", + F3: "F3", + F4: "F4", + F5: "F5", + F6: "F6", + F7: "F7", + M0: "M0", + M1: "M1", + M2: "M2", + M3: "M3", + M4: "M4", + M5: "M5", + M6: "M6", + M7: "M7", + X0: "X0", + X1: "X1", + X2: "X2", + X3: "X3", + X4: "X4", + X5: "X5", + X6: "X6", + X7: "X7", + X8: "X8", + X9: "X9", + X10: "X10", + X11: "X11", + X12: "X12", + X13: "X13", + X14: "X14", + X15: "X15", + CS: "CS", + SS: "SS", + DS: "DS", + ES: "ES", + FS: "FS", + GS: "GS", + GDTR: "GDTR", + IDTR: "IDTR", + LDTR: "LDTR", + MSW: "MSW", + TASK: "TASK", + CR0: "CR0", + CR1: "CR1", + CR2: "CR2", + CR3: "CR3", + CR4: "CR4", + CR5: "CR5", + CR6: "CR6", + CR7: "CR7", + CR8: "CR8", + CR9: "CR9", + CR10: "CR10", + CR11: "CR11", + CR12: "CR12", + CR13: "CR13", + CR14: "CR14", + CR15: "CR15", + DR0: "DR0", + DR1: "DR1", + DR2: "DR2", + DR3: "DR3", + DR4: "DR4", + DR5: "DR5", + DR6: "DR6", + DR7: "DR7", + DR8: "DR8", + DR9: "DR9", + DR10: "DR10", + DR11: "DR11", + DR12: "DR12", + DR13: "DR13", + DR14: "DR14", + DR15: "DR15", + TR0: "TR0", + TR1: "TR1", + TR2: "TR2", + TR3: "TR3", + TR4: "TR4", + TR5: "TR5", + TR6: "TR6", + TR7: "TR7", +} diff --git a/vendor/golang.org/x/sync/singleflight/singleflight.go b/vendor/golang.org/x/sync/singleflight/singleflight.go new file mode 100644 index 0000000000..8473fb7922 --- /dev/null +++ b/vendor/golang.org/x/sync/singleflight/singleflight.go @@ -0,0 +1,205 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package singleflight provides a duplicate function call suppression +// mechanism. +package singleflight // import "golang.org/x/sync/singleflight" + +import ( + "bytes" + "errors" + "fmt" + "runtime" + "runtime/debug" + "sync" +) + +// errGoexit indicates the runtime.Goexit was called in +// the user given function. 
+var errGoexit = errors.New("runtime.Goexit was called")
+
+// A panicError is an arbitrary value recovered from a panic
+// with the stack trace during the execution of the given function.
+type panicError struct {
+	value interface{}
+	stack []byte
+}
+
+// Error implements the error interface.
+func (p *panicError) Error() string {
+	return fmt.Sprintf("%v\n\n%s", p.value, p.stack)
+}
+
+func newPanicError(v interface{}) error {
+	stack := debug.Stack()
+
+	// The first line of the stack trace is of the form "goroutine N [status]:"
+	// but by the time the panic reaches Do the goroutine may no longer exist
+	// and its status will have changed. Trim out the misleading line.
+	if line := bytes.IndexByte(stack[:], '\n'); line >= 0 {
+		stack = stack[line+1:]
+	}
+	return &panicError{value: v, stack: stack}
+}
+
+// call is an in-flight or completed singleflight.Do call.
+type call struct {
+	wg sync.WaitGroup
+
+	// These fields are written once before the WaitGroup is done
+	// and are only read after the WaitGroup is done.
+	val interface{}
+	err error
+
+	// These fields are read and written with the singleflight
+	// mutex held before the WaitGroup is done, and are read but
+	// not written after the WaitGroup is done.
+	dups  int
+	chans []chan<- Result
+}
+
+// Group represents a class of work and forms a namespace in
+// which units of work can be executed with duplicate suppression.
+type Group struct {
+	mu sync.Mutex       // protects m
+	m  map[string]*call // lazily initialized
+}
+
+// Result holds the results of Do, so they can be passed
+// on a channel.
+type Result struct {
+	Val    interface{}
+	Err    error
+	Shared bool
+}
+
+// Do executes and returns the results of the given function, making
+// sure that only one execution is in-flight for a given key at a
+// time. If a duplicate comes in, the duplicate caller waits for the
+// original to complete and receives the same results.
+// The return value shared indicates whether v was given to multiple callers.
+func (g *Group) Do(key string, fn func() (interface{}, error)) (v interface{}, err error, shared bool) {
+	g.mu.Lock()
+	if g.m == nil {
+		g.m = make(map[string]*call)
+	}
+	if c, ok := g.m[key]; ok {
+		c.dups++
+		g.mu.Unlock()
+		c.wg.Wait()
+
+		if e, ok := c.err.(*panicError); ok {
+			panic(e)
+		} else if c.err == errGoexit {
+			runtime.Goexit()
+		}
+		return c.val, c.err, true
+	}
+	c := new(call)
+	c.wg.Add(1)
+	g.m[key] = c
+	g.mu.Unlock()
+
+	g.doCall(c, key, fn)
+	return c.val, c.err, c.dups > 0
+}
+
+// DoChan is like Do but returns a channel that will receive the
+// results when they are ready.
+//
+// The returned channel will not be closed.
+func (g *Group) DoChan(key string, fn func() (interface{}, error)) <-chan Result {
+	ch := make(chan Result, 1)
+	g.mu.Lock()
+	if g.m == nil {
+		g.m = make(map[string]*call)
+	}
+	if c, ok := g.m[key]; ok {
+		c.dups++
+		c.chans = append(c.chans, ch)
+		g.mu.Unlock()
+		return ch
+	}
+	c := &call{chans: []chan<- Result{ch}}
+	c.wg.Add(1)
+	g.m[key] = c
+	g.mu.Unlock()
+
+	go g.doCall(c, key, fn)
+
+	return ch
+}
+
+// doCall handles the single call for a key.
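+// It runs fn once, delivers the result to every waiter, and then removes the
+// key from the map so that later calls start fresh.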
+func (g *Group) doCall(c *call, key string, fn func() (interface{}, error)) {
+	normalReturn := false
+	recovered := false
+
+	// Use double-defer to distinguish panic from runtime.Goexit;
+	// for more details see https://golang.org/cl/134395.
+	defer func() {
+		// the given function invoked runtime.Goexit
+		if !normalReturn && !recovered {
+			c.err = errGoexit
+		}
+
+		g.mu.Lock()
+		defer g.mu.Unlock()
+		c.wg.Done()
+		if g.m[key] == c {
+			delete(g.m, key)
+		}
+
+		if e, ok := c.err.(*panicError); ok {
+			// To prevent the waiting channels from being blocked forever,
+			// we need to ensure that this panic cannot be recovered.
+			if len(c.chans) > 0 {
+				go panic(e)
+				select {} // Keep this goroutine around so that it will appear in the crash dump.
+			} else {
+				panic(e)
+			}
+		} else if c.err == errGoexit {
+			// Already in the process of goexit, no need to call again
+		} else {
+			// Normal return
+			for _, ch := range c.chans {
+				ch <- Result{c.val, c.err, c.dups > 0}
+			}
+		}
+	}()
+
+	func() {
+		defer func() {
+			if !normalReturn {
+				// Ideally, we would wait to take a stack trace until we've determined
+				// whether this is a panic or a runtime.Goexit.
+				//
+				// Unfortunately, the only way we can distinguish the two is to see
+				// whether the recover stopped the goroutine from terminating, and by
+				// the time we know that, the part of the stack trace relevant to the
+				// panic has been discarded.
+				if r := recover(); r != nil {
+					c.err = newPanicError(r)
+				}
+			}
+		}()
+
+		c.val, c.err = fn()
+		normalReturn = true
+	}()
+
+	if !normalReturn {
+		recovered = true
+	}
+}
+
+// Forget tells the singleflight to forget about a key. Future calls
+// to Do for this key will call the function rather than waiting for
+// an earlier call to complete.
+func (g *Group) Forget(key string) {
+	g.mu.Lock()
+	delete(g.m, key)
+	g.mu.Unlock()
+}
diff --git a/vendor/golang.org/x/sys/unix/syscall_linux.go b/vendor/golang.org/x/sys/unix/syscall_linux.go
new file mode 100644
index 0000000000..fb4e50224c
--- /dev/null
+++ b/vendor/golang.org/x/sys/unix/syscall_linux.go
@@ -0,0 +1,2484 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Linux system calls.
+// This file is compiled as ordinary Go code,
+// but it is also input to mksyscall,
+// which parses the //sys lines and generates system call stubs.
+// Note that sometimes we use a lowercase //sys name and
+// wrap it in our own nicer implementation.
+
+package unix
+
+import (
+	"encoding/binary"
+	"strconv"
+	"syscall"
+	"time"
+	"unsafe"
+)
+
+/*
+ * Wrapped
+ */
+
+func Access(path string, mode uint32) (err error) {
+	return Faccessat(AT_FDCWD, path, mode, 0)
+}
+
+func Chmod(path string, mode uint32) (err error) {
+	return Fchmodat(AT_FDCWD, path, mode, 0)
+}
+
+func Chown(path string, uid int, gid int) (err error) {
+	return Fchownat(AT_FDCWD, path, uid, gid, 0)
+}
+
+func Creat(path string, mode uint32) (fd int, err error) {
+	return Open(path, O_CREAT|O_WRONLY|O_TRUNC, mode)
+}
+
+func EpollCreate(size int) (fd int, err error) {
+	if size <= 0 {
+		return -1, EINVAL
+	}
+	return EpollCreate1(0)
+}
+
+//sys	FanotifyInit(flags uint, event_f_flags uint) (fd int, err error)
+//sys	fanotifyMark(fd int, flags uint, mask uint64, dirFd int, pathname *byte) (err error)
+
+func FanotifyMark(fd int, flags uint, mask uint64, dirFd int, pathname string) (err error) {
+	if pathname == "" {
+		return fanotifyMark(fd, flags, mask, dirFd, nil)
+	}
+	p, err := BytePtrFromString(pathname)
+	if err != nil {
+		return err
+	}
+	return fanotifyMark(fd, flags, mask, dirFd, p)
+}
+
+//sys	fchmodat(dirfd int, path string, mode uint32) (err error)
+
+func Fchmodat(dirfd int, path string, mode uint32, flags int) (err error) {
+	// Linux fchmodat doesn't support the flags parameter. Mimic glibc's behavior
+	// and check the flags. Otherwise the mode would be applied to the symlink
+	// destination which is not what the user expects.
+	if flags&^AT_SYMLINK_NOFOLLOW != 0 {
+		return EINVAL
+	} else if flags&AT_SYMLINK_NOFOLLOW != 0 {
+		return EOPNOTSUPP
+	}
+	return fchmodat(dirfd, path, mode)
+}
+
+func InotifyInit() (fd int, err error) {
+	return InotifyInit1(0)
+}
+
+//sys	ioctl(fd int, req uint, arg uintptr) (err error) = SYS_IOCTL
+//sys	ioctlPtr(fd int, req uint, arg unsafe.Pointer) (err error) = SYS_IOCTL
+
+// ioctl itself should not be exposed directly, but additional get/set functions
+// for specific types are permissible. These are defined in ioctl.go and
+// ioctl_linux.go.
+//
+// The third argument to ioctl is often a pointer but sometimes an integer.
+// Callers should use ioctlPtr when the third argument is a pointer and ioctl
+// when the third argument is an integer.
+//
+// TODO: some existing code incorrectly uses ioctl when it should use ioctlPtr.
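+//
+// As a usage sketch (IoctlGetInt is defined in ioctl_linux.go, which is not
+// part of this diff; TIOCGETD is a generated ioctl constant):
+//
+//	ld, err := unix.IoctlGetInt(fd, unix.TIOCGETD) // read the tty line discipline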
+ +//sys Linkat(olddirfd int, oldpath string, newdirfd int, newpath string, flags int) (err error) + +func Link(oldpath string, newpath string) (err error) { + return Linkat(AT_FDCWD, oldpath, AT_FDCWD, newpath, 0) +} + +func Mkdir(path string, mode uint32) (err error) { + return Mkdirat(AT_FDCWD, path, mode) +} + +func Mknod(path string, mode uint32, dev int) (err error) { + return Mknodat(AT_FDCWD, path, mode, dev) +} + +func Open(path string, mode int, perm uint32) (fd int, err error) { + return openat(AT_FDCWD, path, mode|O_LARGEFILE, perm) +} + +//sys openat(dirfd int, path string, flags int, mode uint32) (fd int, err error) + +func Openat(dirfd int, path string, flags int, mode uint32) (fd int, err error) { + return openat(dirfd, path, flags|O_LARGEFILE, mode) +} + +//sys openat2(dirfd int, path string, open_how *OpenHow, size int) (fd int, err error) + +func Openat2(dirfd int, path string, how *OpenHow) (fd int, err error) { + return openat2(dirfd, path, how, SizeofOpenHow) +} + +func Pipe(p []int) error { + return Pipe2(p, 0) +} + +//sysnb pipe2(p *[2]_C_int, flags int) (err error) + +func Pipe2(p []int, flags int) error { + if len(p) != 2 { + return EINVAL + } + var pp [2]_C_int + err := pipe2(&pp, flags) + if err == nil { + p[0] = int(pp[0]) + p[1] = int(pp[1]) + } + return err +} + +//sys ppoll(fds *PollFd, nfds int, timeout *Timespec, sigmask *Sigset_t) (n int, err error) + +func Ppoll(fds []PollFd, timeout *Timespec, sigmask *Sigset_t) (n int, err error) { + if len(fds) == 0 { + return ppoll(nil, 0, timeout, sigmask) + } + return ppoll(&fds[0], len(fds), timeout, sigmask) +} + +func Poll(fds []PollFd, timeout int) (n int, err error) { + var ts *Timespec + if timeout >= 0 { + ts = new(Timespec) + *ts = NsecToTimespec(int64(timeout) * 1e6) + } + return Ppoll(fds, ts, nil) +} + +//sys Readlinkat(dirfd int, path string, buf []byte) (n int, err error) + +func Readlink(path string, buf []byte) (n int, err error) { + return Readlinkat(AT_FDCWD, path, buf) +} + +func Rename(oldpath string, newpath string) (err error) { + return Renameat(AT_FDCWD, oldpath, AT_FDCWD, newpath) +} + +func Rmdir(path string) error { + return Unlinkat(AT_FDCWD, path, AT_REMOVEDIR) +} + +//sys Symlinkat(oldpath string, newdirfd int, newpath string) (err error) + +func Symlink(oldpath string, newpath string) (err error) { + return Symlinkat(oldpath, AT_FDCWD, newpath) +} + +func Unlink(path string) error { + return Unlinkat(AT_FDCWD, path, 0) +} + +//sys Unlinkat(dirfd int, path string, flags int) (err error) + +func Utimes(path string, tv []Timeval) error { + if tv == nil { + err := utimensat(AT_FDCWD, path, nil, 0) + if err != ENOSYS { + return err + } + return utimes(path, nil) + } + if len(tv) != 2 { + return EINVAL + } + var ts [2]Timespec + ts[0] = NsecToTimespec(TimevalToNsec(tv[0])) + ts[1] = NsecToTimespec(TimevalToNsec(tv[1])) + err := utimensat(AT_FDCWD, path, (*[2]Timespec)(unsafe.Pointer(&ts[0])), 0) + if err != ENOSYS { + return err + } + return utimes(path, (*[2]Timeval)(unsafe.Pointer(&tv[0]))) +} + +//sys utimensat(dirfd int, path string, times *[2]Timespec, flags int) (err error) + +func UtimesNano(path string, ts []Timespec) error { + return UtimesNanoAt(AT_FDCWD, path, ts, 0) +} + +func UtimesNanoAt(dirfd int, path string, ts []Timespec, flags int) error { + if ts == nil { + return utimensat(dirfd, path, nil, flags) + } + if len(ts) != 2 { + return EINVAL + } + return utimensat(dirfd, path, (*[2]Timespec)(unsafe.Pointer(&ts[0])), flags) +} + +func Futimesat(dirfd int, path string, tv 
[]Timeval) error { + if tv == nil { + return futimesat(dirfd, path, nil) + } + if len(tv) != 2 { + return EINVAL + } + return futimesat(dirfd, path, (*[2]Timeval)(unsafe.Pointer(&tv[0]))) +} + +func Futimes(fd int, tv []Timeval) (err error) { + // Believe it or not, this is the best we can do on Linux + // (and is what glibc does). + return Utimes("/proc/self/fd/"+strconv.Itoa(fd), tv) +} + +const ImplementsGetwd = true + +//sys Getcwd(buf []byte) (n int, err error) + +func Getwd() (wd string, err error) { + var buf [PathMax]byte + n, err := Getcwd(buf[0:]) + if err != nil { + return "", err + } + // Getcwd returns the number of bytes written to buf, including the NUL. + if n < 1 || n > len(buf) || buf[n-1] != 0 { + return "", EINVAL + } + // In some cases, Linux can return a path that starts with the + // "(unreachable)" prefix, which can potentially be a valid relative + // path. To work around that, return ENOENT if path is not absolute. + if buf[0] != '/' { + return "", ENOENT + } + + return string(buf[0 : n-1]), nil +} + +func Getgroups() (gids []int, err error) { + n, err := getgroups(0, nil) + if err != nil { + return nil, err + } + if n == 0 { + return nil, nil + } + + // Sanity check group count. Max is 1<<16 on Linux. + if n < 0 || n > 1<<20 { + return nil, EINVAL + } + + a := make([]_Gid_t, n) + n, err = getgroups(n, &a[0]) + if err != nil { + return nil, err + } + gids = make([]int, n) + for i, v := range a[0:n] { + gids[i] = int(v) + } + return +} + +func Setgroups(gids []int) (err error) { + if len(gids) == 0 { + return setgroups(0, nil) + } + + a := make([]_Gid_t, len(gids)) + for i, v := range gids { + a[i] = _Gid_t(v) + } + return setgroups(len(a), &a[0]) +} + +type WaitStatus uint32 + +// Wait status is 7 bits at bottom, either 0 (exited), +// 0x7F (stopped), or a signal number that caused an exit. +// The 0x80 bit is whether there was a core dump. +// An extra number (exit code, signal causing a stop) +// is in the high bits. At least that's the idea. +// There are various irregularities. For example, the +// "continued" status is 0xFFFF, distinguishing itself +// from stopped via the core dump bit. 
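+//
+// Concretely, a child that exits with status 3 produces the wait status
+// word 0x0300: the low 7 bits are zero (exited) and the exit code 3 sits
+// in the high byte. A child killed by SIGKILL produces 0x0009 instead:
+//
+//	var w WaitStatus = 0x0300
+//	w.Exited()     // true
+//	w.ExitStatus() // 3
+//
+//	w = 0x0009
+//	w.Signaled() // true
+//	w.Signal()   // SIGKILL (9)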
+ +const ( + mask = 0x7F + core = 0x80 + exited = 0x00 + stopped = 0x7F + shift = 8 +) + +func (w WaitStatus) Exited() bool { return w&mask == exited } + +func (w WaitStatus) Signaled() bool { return w&mask != stopped && w&mask != exited } + +func (w WaitStatus) Stopped() bool { return w&0xFF == stopped } + +func (w WaitStatus) Continued() bool { return w == 0xFFFF } + +func (w WaitStatus) CoreDump() bool { return w.Signaled() && w&core != 0 } + +func (w WaitStatus) ExitStatus() int { + if !w.Exited() { + return -1 + } + return int(w>>shift) & 0xFF +} + +func (w WaitStatus) Signal() syscall.Signal { + if !w.Signaled() { + return -1 + } + return syscall.Signal(w & mask) +} + +func (w WaitStatus) StopSignal() syscall.Signal { + if !w.Stopped() { + return -1 + } + return syscall.Signal(w>>shift) & 0xFF +} + +func (w WaitStatus) TrapCause() int { + if w.StopSignal() != SIGTRAP { + return -1 + } + return int(w>>shift) >> 8 +} + +//sys wait4(pid int, wstatus *_C_int, options int, rusage *Rusage) (wpid int, err error) + +func Wait4(pid int, wstatus *WaitStatus, options int, rusage *Rusage) (wpid int, err error) { + var status _C_int + wpid, err = wait4(pid, &status, options, rusage) + if wstatus != nil { + *wstatus = WaitStatus(status) + } + return +} + +//sys Waitid(idType int, id int, info *Siginfo, options int, rusage *Rusage) (err error) + +func Mkfifo(path string, mode uint32) error { + return Mknod(path, mode|S_IFIFO, 0) +} + +func Mkfifoat(dirfd int, path string, mode uint32) error { + return Mknodat(dirfd, path, mode|S_IFIFO, 0) +} + +func (sa *SockaddrInet4) sockaddr() (unsafe.Pointer, _Socklen, error) { + if sa.Port < 0 || sa.Port > 0xFFFF { + return nil, 0, EINVAL + } + sa.raw.Family = AF_INET + p := (*[2]byte)(unsafe.Pointer(&sa.raw.Port)) + p[0] = byte(sa.Port >> 8) + p[1] = byte(sa.Port) + sa.raw.Addr = sa.Addr + return unsafe.Pointer(&sa.raw), SizeofSockaddrInet4, nil +} + +func (sa *SockaddrInet6) sockaddr() (unsafe.Pointer, _Socklen, error) { + if sa.Port < 0 || sa.Port > 0xFFFF { + return nil, 0, EINVAL + } + sa.raw.Family = AF_INET6 + p := (*[2]byte)(unsafe.Pointer(&sa.raw.Port)) + p[0] = byte(sa.Port >> 8) + p[1] = byte(sa.Port) + sa.raw.Scope_id = sa.ZoneId + sa.raw.Addr = sa.Addr + return unsafe.Pointer(&sa.raw), SizeofSockaddrInet6, nil +} + +func (sa *SockaddrUnix) sockaddr() (unsafe.Pointer, _Socklen, error) { + name := sa.Name + n := len(name) + if n >= len(sa.raw.Path) { + return nil, 0, EINVAL + } + sa.raw.Family = AF_UNIX + for i := 0; i < n; i++ { + sa.raw.Path[i] = int8(name[i]) + } + // length is family (uint16), name, NUL. + sl := _Socklen(2) + if n > 0 { + sl += _Socklen(n) + 1 + } + if sa.raw.Path[0] == '@' { + sa.raw.Path[0] = 0 + // Don't count trailing NUL for abstract address. + sl-- + } + + return unsafe.Pointer(&sa.raw), sl, nil +} + +// SockaddrLinklayer implements the Sockaddr interface for AF_PACKET type sockets. 
+type SockaddrLinklayer struct {
+	Protocol uint16
+	Ifindex  int
+	Hatype   uint16
+	Pkttype  uint8
+	Halen    uint8
+	Addr     [8]byte
+	raw      RawSockaddrLinklayer
+}
+
+func (sa *SockaddrLinklayer) sockaddr() (unsafe.Pointer, _Socklen, error) {
+	if sa.Ifindex < 0 || sa.Ifindex > 0x7fffffff {
+		return nil, 0, EINVAL
+	}
+	sa.raw.Family = AF_PACKET
+	sa.raw.Protocol = sa.Protocol
+	sa.raw.Ifindex = int32(sa.Ifindex)
+	sa.raw.Hatype = sa.Hatype
+	sa.raw.Pkttype = sa.Pkttype
+	sa.raw.Halen = sa.Halen
+	sa.raw.Addr = sa.Addr
+	return unsafe.Pointer(&sa.raw), SizeofSockaddrLinklayer, nil
+}
+
+// SockaddrNetlink implements the Sockaddr interface for AF_NETLINK type sockets.
+type SockaddrNetlink struct {
+	Family uint16
+	Pad    uint16
+	Pid    uint32
+	Groups uint32
+	raw    RawSockaddrNetlink
+}
+
+func (sa *SockaddrNetlink) sockaddr() (unsafe.Pointer, _Socklen, error) {
+	sa.raw.Family = AF_NETLINK
+	sa.raw.Pad = sa.Pad
+	sa.raw.Pid = sa.Pid
+	sa.raw.Groups = sa.Groups
+	return unsafe.Pointer(&sa.raw), SizeofSockaddrNetlink, nil
+}
+
+// SockaddrHCI implements the Sockaddr interface for AF_BLUETOOTH type sockets
+// using the HCI protocol.
+type SockaddrHCI struct {
+	Dev     uint16
+	Channel uint16
+	raw     RawSockaddrHCI
+}
+
+func (sa *SockaddrHCI) sockaddr() (unsafe.Pointer, _Socklen, error) {
+	sa.raw.Family = AF_BLUETOOTH
+	sa.raw.Dev = sa.Dev
+	sa.raw.Channel = sa.Channel
+	return unsafe.Pointer(&sa.raw), SizeofSockaddrHCI, nil
+}
+
+// SockaddrL2 implements the Sockaddr interface for AF_BLUETOOTH type sockets
+// using the L2CAP protocol.
+type SockaddrL2 struct {
+	PSM      uint16
+	CID      uint16
+	Addr     [6]uint8
+	AddrType uint8
+	raw      RawSockaddrL2
+}
+
+func (sa *SockaddrL2) sockaddr() (unsafe.Pointer, _Socklen, error) {
+	sa.raw.Family = AF_BLUETOOTH
+	psm := (*[2]byte)(unsafe.Pointer(&sa.raw.Psm))
+	psm[0] = byte(sa.PSM)
+	psm[1] = byte(sa.PSM >> 8)
+	for i := 0; i < len(sa.Addr); i++ {
+		sa.raw.Bdaddr[i] = sa.Addr[len(sa.Addr)-1-i]
+	}
+	cid := (*[2]byte)(unsafe.Pointer(&sa.raw.Cid))
+	cid[0] = byte(sa.CID)
+	cid[1] = byte(sa.CID >> 8)
+	sa.raw.Bdaddr_type = sa.AddrType
+	return unsafe.Pointer(&sa.raw), SizeofSockaddrL2, nil
+}
+
+// SockaddrRFCOMM implements the Sockaddr interface for AF_BLUETOOTH type sockets
+// using the RFCOMM protocol.
+//
+// Server example:
+//
+//	fd, _ := Socket(AF_BLUETOOTH, SOCK_STREAM, BTPROTO_RFCOMM)
+//	_ = unix.Bind(fd, &unix.SockaddrRFCOMM{
+//		Channel: 1,
+//		Addr:    [6]uint8{0, 0, 0, 0, 0, 0}, // BDADDR_ANY or 00:00:00:00:00:00
+//	})
+//	_ = Listen(fd, 1)
+//	nfd, sa, _ := Accept(fd)
+//	fmt.Printf("conn addr=%v fd=%d", sa.(*unix.SockaddrRFCOMM).Addr, nfd)
+//	Read(nfd, buf)
+//
+// Client example:
+//
+//	fd, _ := Socket(AF_BLUETOOTH, SOCK_STREAM, BTPROTO_RFCOMM)
+//	_ = Connect(fd, &SockaddrRFCOMM{
+//		Channel: 1,
+//		Addr:    [6]byte{0x11, 0x22, 0x33, 0xaa, 0xbb, 0xcc}, // CC:BB:AA:33:22:11
+//	})
+//	Write(fd, []byte(`hello`))
+type SockaddrRFCOMM struct {
+	// Addr represents a bluetooth address, byte ordering is little-endian.
+	Addr [6]uint8
+
+	// Channel is a designated bluetooth channel; only 1-30 are available for use.
+	// Since Linux 2.6.7, a zero value selects the first available channel.
+	Channel uint8
+
+	raw RawSockaddrRFCOMM
+}
+
+func (sa *SockaddrRFCOMM) sockaddr() (unsafe.Pointer, _Socklen, error) {
+	sa.raw.Family = AF_BLUETOOTH
+	sa.raw.Channel = sa.Channel
+	sa.raw.Bdaddr = sa.Addr
+	return unsafe.Pointer(&sa.raw), SizeofSockaddrRFCOMM, nil
+}
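+
+// A connect sketch for SockaddrL2, in the same spirit as the RFCOMM
+// examples above (the PSM and the remote address are illustrative only):
+//
+//	fd, _ := unix.Socket(unix.AF_BLUETOOTH, unix.SOCK_SEQPACKET, unix.BTPROTO_L2CAP)
+//	_ = unix.Connect(fd, &unix.SockaddrL2{
+//		PSM:  0x1001,
+//		Addr: [6]uint8{0x11, 0x22, 0x33, 0xaa, 0xbb, 0xcc}, // CC:BB:AA:33:22:11
+//	})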
+// SockaddrCAN implements the Sockaddr interface for AF_CAN type sockets.
+// The RxID and TxID fields are used for transport protocol addressing
+// (CAN_TP16, CAN_TP20, CAN_MCNET, and CAN_ISOTP); they can be left as
+// zero values for CAN_RAW and CAN_BCM sockets, where they have no meaning.
+//
+// The SockaddrCAN struct must be bound to the socket file descriptor
+// using Bind before the CAN socket can be used.
+//
+//	// Read one raw CAN frame
+//	fd, _ := Socket(AF_CAN, SOCK_RAW, CAN_RAW)
+//	addr := &SockaddrCAN{Ifindex: index}
+//	Bind(fd, addr)
+//	frame := make([]byte, 16)
+//	Read(fd, frame)
+//
+// The full SocketCAN documentation can be found in the linux kernel
+// archives at: https://www.kernel.org/doc/Documentation/networking/can.txt
+type SockaddrCAN struct {
+	Ifindex int
+	RxID    uint32
+	TxID    uint32
+	raw     RawSockaddrCAN
+}
+
+func (sa *SockaddrCAN) sockaddr() (unsafe.Pointer, _Socklen, error) {
+	if sa.Ifindex < 0 || sa.Ifindex > 0x7fffffff {
+		return nil, 0, EINVAL
+	}
+	sa.raw.Family = AF_CAN
+	sa.raw.Ifindex = int32(sa.Ifindex)
+	rx := (*[4]byte)(unsafe.Pointer(&sa.RxID))
+	for i := 0; i < 4; i++ {
+		sa.raw.Addr[i] = rx[i]
+	}
+	tx := (*[4]byte)(unsafe.Pointer(&sa.TxID))
+	for i := 0; i < 4; i++ {
+		sa.raw.Addr[i+4] = tx[i]
+	}
+	return unsafe.Pointer(&sa.raw), SizeofSockaddrCAN, nil
+}
+
+// SockaddrCANJ1939 implements the Sockaddr interface for AF_CAN using J1939
+// protocol (https://en.wikipedia.org/wiki/SAE_J1939). For more information
+// on the purposes of the fields, check the official linux kernel documentation
+// available here: https://www.kernel.org/doc/Documentation/networking/j1939.rst
+type SockaddrCANJ1939 struct {
+	Ifindex int
+	Name    uint64
+	PGN     uint32
+	Addr    uint8
+	raw     RawSockaddrCAN
+}
+
+func (sa *SockaddrCANJ1939) sockaddr() (unsafe.Pointer, _Socklen, error) {
+	if sa.Ifindex < 0 || sa.Ifindex > 0x7fffffff {
+		return nil, 0, EINVAL
+	}
+	sa.raw.Family = AF_CAN
+	sa.raw.Ifindex = int32(sa.Ifindex)
+	n := (*[8]byte)(unsafe.Pointer(&sa.Name))
+	for i := 0; i < 8; i++ {
+		sa.raw.Addr[i] = n[i]
+	}
+	p := (*[4]byte)(unsafe.Pointer(&sa.PGN))
+	for i := 0; i < 4; i++ {
+		sa.raw.Addr[i+8] = p[i]
+	}
+	sa.raw.Addr[12] = sa.Addr
+	return unsafe.Pointer(&sa.raw), SizeofSockaddrCAN, nil
+}
+
+// SockaddrALG implements the Sockaddr interface for AF_ALG type sockets.
+// SockaddrALG enables userspace access to the Linux kernel's cryptography
+// subsystem. The Type and Name fields specify which type of hash or cipher
+// should be used with a given socket.
+//
+// To create a file descriptor that provides access to a hash or cipher, both
+// Bind and Accept must be used. Once the setup process is complete, input
+// data can be written to the socket, processed by the kernel, and then read
+// back as hash output or ciphertext.
+//
+// Here is an example of using an AF_ALG socket with SHA1 hashing.
+// The initial socket setup process is as follows:
+//
+//	// Open a socket to perform SHA1 hashing.
+//	fd, _ := unix.Socket(unix.AF_ALG, unix.SOCK_SEQPACKET, 0)
+//	addr := &unix.SockaddrALG{Type: "hash", Name: "sha1"}
+//	unix.Bind(fd, addr)
+//	// Note: unix.Accept does not work at this time; must invoke accept()
+//	// manually using unix.Syscall.
+//	hashfd, _, _ := unix.Syscall(unix.SYS_ACCEPT, uintptr(fd), 0, 0)
+//
+// Once a file descriptor has been returned from Accept, it may be used to
+// perform SHA1 hashing. The descriptor is not safe for concurrent use, but
+// may be re-used repeatedly with subsequent Write and Read operations.
+// +// When hashing a small byte slice or string, a single Write and Read may +// be used: +// +// // Assume hashfd is already configured using the setup process. +// hash := os.NewFile(hashfd, "sha1") +// // Hash an input string and read the results. Each Write discards +// // previous hash state. Read always reads the current state. +// b := make([]byte, 20) +// for i := 0; i < 2; i++ { +// io.WriteString(hash, "Hello, world.") +// hash.Read(b) +// fmt.Println(hex.EncodeToString(b)) +// } +// // Output: +// // 2ae01472317d1935a84797ec1983ae243fc6aa28 +// // 2ae01472317d1935a84797ec1983ae243fc6aa28 +// +// For hashing larger byte slices, or byte streams such as those read from +// a file or socket, use Sendto with MSG_MORE to instruct the kernel to update +// the hash digest instead of creating a new one for a given chunk and finalizing it. +// +// // Assume hashfd and addr are already configured using the setup process. +// hash := os.NewFile(hashfd, "sha1") +// // Hash the contents of a file. +// f, _ := os.Open("/tmp/linux-4.10-rc7.tar.xz") +// b := make([]byte, 4096) +// for { +// n, err := f.Read(b) +// if err == io.EOF { +// break +// } +// unix.Sendto(hashfd, b[:n], unix.MSG_MORE, addr) +// } +// hash.Read(b) +// fmt.Println(hex.EncodeToString(b)) +// // Output: 85cdcad0c06eef66f805ecce353bec9accbeecc5 +// +// For more information, see: http://www.chronox.de/crypto-API/crypto/userspace-if.html. +type SockaddrALG struct { + Type string + Name string + Feature uint32 + Mask uint32 + raw RawSockaddrALG +} + +func (sa *SockaddrALG) sockaddr() (unsafe.Pointer, _Socklen, error) { + // Leave room for NUL byte terminator. + if len(sa.Type) > len(sa.raw.Type)-1 { + return nil, 0, EINVAL + } + if len(sa.Name) > len(sa.raw.Name)-1 { + return nil, 0, EINVAL + } + + sa.raw.Family = AF_ALG + sa.raw.Feat = sa.Feature + sa.raw.Mask = sa.Mask + + copy(sa.raw.Type[:], sa.Type) + copy(sa.raw.Name[:], sa.Name) + + return unsafe.Pointer(&sa.raw), SizeofSockaddrALG, nil +} + +// SockaddrVM implements the Sockaddr interface for AF_VSOCK type sockets. +// SockaddrVM provides access to Linux VM sockets: a mechanism that enables +// bidirectional communication between a hypervisor and its guest virtual +// machines. +type SockaddrVM struct { + // CID and Port specify a context ID and port address for a VM socket. + // Guests have a unique CID, and hosts may have a well-known CID of: + // - VMADDR_CID_HYPERVISOR: refers to the hypervisor process. + // - VMADDR_CID_LOCAL: refers to local communication (loopback). + // - VMADDR_CID_HOST: refers to other processes on the host. + CID uint32 + Port uint32 + Flags uint8 + raw RawSockaddrVM +} + +func (sa *SockaddrVM) sockaddr() (unsafe.Pointer, _Socklen, error) { + sa.raw.Family = AF_VSOCK + sa.raw.Port = sa.Port + sa.raw.Cid = sa.CID + sa.raw.Flags = sa.Flags + + return unsafe.Pointer(&sa.raw), SizeofSockaddrVM, nil +} + +type SockaddrXDP struct { + Flags uint16 + Ifindex uint32 + QueueID uint32 + SharedUmemFD uint32 + raw RawSockaddrXDP +} + +func (sa *SockaddrXDP) sockaddr() (unsafe.Pointer, _Socklen, error) { + sa.raw.Family = AF_XDP + sa.raw.Flags = sa.Flags + sa.raw.Ifindex = sa.Ifindex + sa.raw.Queue_id = sa.QueueID + sa.raw.Shared_umem_fd = sa.SharedUmemFD + + return unsafe.Pointer(&sa.raw), SizeofSockaddrXDP, nil +} + +// This constant mirrors the #define of PX_PROTO_OE in +// linux/if_pppox.h. 
We're defining this by hand here instead of +// autogenerating through mkerrors.sh because including +// linux/if_pppox.h causes some declaration conflicts with other +// includes (linux/if_pppox.h includes linux/in.h, which conflicts +// with netinet/in.h). Given that we only need a single zero constant +// out of that file, it's cleaner to just define it by hand here. +const px_proto_oe = 0 + +type SockaddrPPPoE struct { + SID uint16 + Remote []byte + Dev string + raw RawSockaddrPPPoX +} + +func (sa *SockaddrPPPoE) sockaddr() (unsafe.Pointer, _Socklen, error) { + if len(sa.Remote) != 6 { + return nil, 0, EINVAL + } + if len(sa.Dev) > IFNAMSIZ-1 { + return nil, 0, EINVAL + } + + *(*uint16)(unsafe.Pointer(&sa.raw[0])) = AF_PPPOX + // This next field is in host-endian byte order. We can't use the + // same unsafe pointer cast as above, because this value is not + // 32-bit aligned and some architectures don't allow unaligned + // access. + // + // However, the value of px_proto_oe is 0, so we can use + // encoding/binary helpers to write the bytes without worrying + // about the ordering. + binary.BigEndian.PutUint32(sa.raw[2:6], px_proto_oe) + // This field is deliberately big-endian, unlike the previous + // one. The kernel expects SID to be in network byte order. + binary.BigEndian.PutUint16(sa.raw[6:8], sa.SID) + copy(sa.raw[8:14], sa.Remote) + for i := 14; i < 14+IFNAMSIZ; i++ { + sa.raw[i] = 0 + } + copy(sa.raw[14:], sa.Dev) + return unsafe.Pointer(&sa.raw), SizeofSockaddrPPPoX, nil +} + +// SockaddrTIPC implements the Sockaddr interface for AF_TIPC type sockets. +// For more information on TIPC, see: http://tipc.sourceforge.net/. +type SockaddrTIPC struct { + // Scope is the publication scopes when binding service/service range. + // Should be set to TIPC_CLUSTER_SCOPE or TIPC_NODE_SCOPE. + Scope int + + // Addr is the type of address used to manipulate a socket. Addr must be + // one of: + // - *TIPCSocketAddr: "id" variant in the C addr union + // - *TIPCServiceRange: "nameseq" variant in the C addr union + // - *TIPCServiceName: "name" variant in the C addr union + // + // If nil, EINVAL will be returned when the structure is used. + Addr TIPCAddr + + raw RawSockaddrTIPC +} + +// TIPCAddr is implemented by types that can be used as an address for +// SockaddrTIPC. It is only implemented by *TIPCSocketAddr, *TIPCServiceRange, +// and *TIPCServiceName. 
+type TIPCAddr interface { + tipcAddrtype() uint8 + tipcAddr() [12]byte +} + +func (sa *TIPCSocketAddr) tipcAddr() [12]byte { + var out [12]byte + copy(out[:], (*(*[unsafe.Sizeof(TIPCSocketAddr{})]byte)(unsafe.Pointer(sa)))[:]) + return out +} + +func (sa *TIPCSocketAddr) tipcAddrtype() uint8 { return TIPC_SOCKET_ADDR } + +func (sa *TIPCServiceRange) tipcAddr() [12]byte { + var out [12]byte + copy(out[:], (*(*[unsafe.Sizeof(TIPCServiceRange{})]byte)(unsafe.Pointer(sa)))[:]) + return out +} + +func (sa *TIPCServiceRange) tipcAddrtype() uint8 { return TIPC_SERVICE_RANGE } + +func (sa *TIPCServiceName) tipcAddr() [12]byte { + var out [12]byte + copy(out[:], (*(*[unsafe.Sizeof(TIPCServiceName{})]byte)(unsafe.Pointer(sa)))[:]) + return out +} + +func (sa *TIPCServiceName) tipcAddrtype() uint8 { return TIPC_SERVICE_ADDR } + +func (sa *SockaddrTIPC) sockaddr() (unsafe.Pointer, _Socklen, error) { + if sa.Addr == nil { + return nil, 0, EINVAL + } + sa.raw.Family = AF_TIPC + sa.raw.Scope = int8(sa.Scope) + sa.raw.Addrtype = sa.Addr.tipcAddrtype() + sa.raw.Addr = sa.Addr.tipcAddr() + return unsafe.Pointer(&sa.raw), SizeofSockaddrTIPC, nil +} + +// SockaddrL2TPIP implements the Sockaddr interface for IPPROTO_L2TP/AF_INET sockets. +type SockaddrL2TPIP struct { + Addr [4]byte + ConnId uint32 + raw RawSockaddrL2TPIP +} + +func (sa *SockaddrL2TPIP) sockaddr() (unsafe.Pointer, _Socklen, error) { + sa.raw.Family = AF_INET + sa.raw.Conn_id = sa.ConnId + sa.raw.Addr = sa.Addr + return unsafe.Pointer(&sa.raw), SizeofSockaddrL2TPIP, nil +} + +// SockaddrL2TPIP6 implements the Sockaddr interface for IPPROTO_L2TP/AF_INET6 sockets. +type SockaddrL2TPIP6 struct { + Addr [16]byte + ZoneId uint32 + ConnId uint32 + raw RawSockaddrL2TPIP6 +} + +func (sa *SockaddrL2TPIP6) sockaddr() (unsafe.Pointer, _Socklen, error) { + sa.raw.Family = AF_INET6 + sa.raw.Conn_id = sa.ConnId + sa.raw.Scope_id = sa.ZoneId + sa.raw.Addr = sa.Addr + return unsafe.Pointer(&sa.raw), SizeofSockaddrL2TPIP6, nil +} + +// SockaddrIUCV implements the Sockaddr interface for AF_IUCV sockets. +type SockaddrIUCV struct { + UserID string + Name string + raw RawSockaddrIUCV +} + +func (sa *SockaddrIUCV) sockaddr() (unsafe.Pointer, _Socklen, error) { + sa.raw.Family = AF_IUCV + // These are EBCDIC encoded by the kernel, but we still need to pad them + // with blanks. Initializing with blanks allows the caller to feed in either + // a padded or an unpadded string. 
+ for i := 0; i < 8; i++ { + sa.raw.Nodeid[i] = ' ' + sa.raw.User_id[i] = ' ' + sa.raw.Name[i] = ' ' + } + if len(sa.UserID) > 8 || len(sa.Name) > 8 { + return nil, 0, EINVAL + } + for i, b := range []byte(sa.UserID[:]) { + sa.raw.User_id[i] = int8(b) + } + for i, b := range []byte(sa.Name[:]) { + sa.raw.Name[i] = int8(b) + } + return unsafe.Pointer(&sa.raw), SizeofSockaddrIUCV, nil +} + +type SockaddrNFC struct { + DeviceIdx uint32 + TargetIdx uint32 + NFCProtocol uint32 + raw RawSockaddrNFC +} + +func (sa *SockaddrNFC) sockaddr() (unsafe.Pointer, _Socklen, error) { + sa.raw.Sa_family = AF_NFC + sa.raw.Dev_idx = sa.DeviceIdx + sa.raw.Target_idx = sa.TargetIdx + sa.raw.Nfc_protocol = sa.NFCProtocol + return unsafe.Pointer(&sa.raw), SizeofSockaddrNFC, nil +} + +type SockaddrNFCLLCP struct { + DeviceIdx uint32 + TargetIdx uint32 + NFCProtocol uint32 + DestinationSAP uint8 + SourceSAP uint8 + ServiceName string + raw RawSockaddrNFCLLCP +} + +func (sa *SockaddrNFCLLCP) sockaddr() (unsafe.Pointer, _Socklen, error) { + sa.raw.Sa_family = AF_NFC + sa.raw.Dev_idx = sa.DeviceIdx + sa.raw.Target_idx = sa.TargetIdx + sa.raw.Nfc_protocol = sa.NFCProtocol + sa.raw.Dsap = sa.DestinationSAP + sa.raw.Ssap = sa.SourceSAP + if len(sa.ServiceName) > len(sa.raw.Service_name) { + return nil, 0, EINVAL + } + copy(sa.raw.Service_name[:], sa.ServiceName) + sa.raw.SetServiceNameLen(len(sa.ServiceName)) + return unsafe.Pointer(&sa.raw), SizeofSockaddrNFCLLCP, nil +} + +var socketProtocol = func(fd int) (int, error) { + return GetsockoptInt(fd, SOL_SOCKET, SO_PROTOCOL) +} + +func anyToSockaddr(fd int, rsa *RawSockaddrAny) (Sockaddr, error) { + switch rsa.Addr.Family { + case AF_NETLINK: + pp := (*RawSockaddrNetlink)(unsafe.Pointer(rsa)) + sa := new(SockaddrNetlink) + sa.Family = pp.Family + sa.Pad = pp.Pad + sa.Pid = pp.Pid + sa.Groups = pp.Groups + return sa, nil + + case AF_PACKET: + pp := (*RawSockaddrLinklayer)(unsafe.Pointer(rsa)) + sa := new(SockaddrLinklayer) + sa.Protocol = pp.Protocol + sa.Ifindex = int(pp.Ifindex) + sa.Hatype = pp.Hatype + sa.Pkttype = pp.Pkttype + sa.Halen = pp.Halen + sa.Addr = pp.Addr + return sa, nil + + case AF_UNIX: + pp := (*RawSockaddrUnix)(unsafe.Pointer(rsa)) + sa := new(SockaddrUnix) + if pp.Path[0] == 0 { + // "Abstract" Unix domain socket. + // Rewrite leading NUL as @ for textual display. + // (This is the standard convention.) + // Not friendly to overwrite in place, + // but the callers below don't care. + pp.Path[0] = '@' + } + + // Assume path ends at NUL. + // This is not technically the Linux semantics for + // abstract Unix domain sockets--they are supposed + // to be uninterpreted fixed-size binary blobs--but + // everyone uses this convention. 
+ n := 0 + for n < len(pp.Path) && pp.Path[n] != 0 { + n++ + } + sa.Name = string(unsafe.Slice((*byte)(unsafe.Pointer(&pp.Path[0])), n)) + return sa, nil + + case AF_INET: + proto, err := socketProtocol(fd) + if err != nil { + return nil, err + } + + switch proto { + case IPPROTO_L2TP: + pp := (*RawSockaddrL2TPIP)(unsafe.Pointer(rsa)) + sa := new(SockaddrL2TPIP) + sa.ConnId = pp.Conn_id + sa.Addr = pp.Addr + return sa, nil + default: + pp := (*RawSockaddrInet4)(unsafe.Pointer(rsa)) + sa := new(SockaddrInet4) + p := (*[2]byte)(unsafe.Pointer(&pp.Port)) + sa.Port = int(p[0])<<8 + int(p[1]) + sa.Addr = pp.Addr + return sa, nil + } + + case AF_INET6: + proto, err := socketProtocol(fd) + if err != nil { + return nil, err + } + + switch proto { + case IPPROTO_L2TP: + pp := (*RawSockaddrL2TPIP6)(unsafe.Pointer(rsa)) + sa := new(SockaddrL2TPIP6) + sa.ConnId = pp.Conn_id + sa.ZoneId = pp.Scope_id + sa.Addr = pp.Addr + return sa, nil + default: + pp := (*RawSockaddrInet6)(unsafe.Pointer(rsa)) + sa := new(SockaddrInet6) + p := (*[2]byte)(unsafe.Pointer(&pp.Port)) + sa.Port = int(p[0])<<8 + int(p[1]) + sa.ZoneId = pp.Scope_id + sa.Addr = pp.Addr + return sa, nil + } + + case AF_VSOCK: + pp := (*RawSockaddrVM)(unsafe.Pointer(rsa)) + sa := &SockaddrVM{ + CID: pp.Cid, + Port: pp.Port, + Flags: pp.Flags, + } + return sa, nil + case AF_BLUETOOTH: + proto, err := socketProtocol(fd) + if err != nil { + return nil, err + } + // only BTPROTO_L2CAP and BTPROTO_RFCOMM can accept connections + switch proto { + case BTPROTO_L2CAP: + pp := (*RawSockaddrL2)(unsafe.Pointer(rsa)) + sa := &SockaddrL2{ + PSM: pp.Psm, + CID: pp.Cid, + Addr: pp.Bdaddr, + AddrType: pp.Bdaddr_type, + } + return sa, nil + case BTPROTO_RFCOMM: + pp := (*RawSockaddrRFCOMM)(unsafe.Pointer(rsa)) + sa := &SockaddrRFCOMM{ + Channel: pp.Channel, + Addr: pp.Bdaddr, + } + return sa, nil + } + case AF_XDP: + pp := (*RawSockaddrXDP)(unsafe.Pointer(rsa)) + sa := &SockaddrXDP{ + Flags: pp.Flags, + Ifindex: pp.Ifindex, + QueueID: pp.Queue_id, + SharedUmemFD: pp.Shared_umem_fd, + } + return sa, nil + case AF_PPPOX: + pp := (*RawSockaddrPPPoX)(unsafe.Pointer(rsa)) + if binary.BigEndian.Uint32(pp[2:6]) != px_proto_oe { + return nil, EINVAL + } + sa := &SockaddrPPPoE{ + SID: binary.BigEndian.Uint16(pp[6:8]), + Remote: pp[8:14], + } + for i := 14; i < 14+IFNAMSIZ; i++ { + if pp[i] == 0 { + sa.Dev = string(pp[14:i]) + break + } + } + return sa, nil + case AF_TIPC: + pp := (*RawSockaddrTIPC)(unsafe.Pointer(rsa)) + + sa := &SockaddrTIPC{ + Scope: int(pp.Scope), + } + + // Determine which union variant is present in pp.Addr by checking + // pp.Addrtype. 
+ switch pp.Addrtype { + case TIPC_SERVICE_RANGE: + sa.Addr = (*TIPCServiceRange)(unsafe.Pointer(&pp.Addr)) + case TIPC_SERVICE_ADDR: + sa.Addr = (*TIPCServiceName)(unsafe.Pointer(&pp.Addr)) + case TIPC_SOCKET_ADDR: + sa.Addr = (*TIPCSocketAddr)(unsafe.Pointer(&pp.Addr)) + default: + return nil, EINVAL + } + + return sa, nil + case AF_IUCV: + pp := (*RawSockaddrIUCV)(unsafe.Pointer(rsa)) + + var user [8]byte + var name [8]byte + + for i := 0; i < 8; i++ { + user[i] = byte(pp.User_id[i]) + name[i] = byte(pp.Name[i]) + } + + sa := &SockaddrIUCV{ + UserID: string(user[:]), + Name: string(name[:]), + } + return sa, nil + + case AF_CAN: + proto, err := socketProtocol(fd) + if err != nil { + return nil, err + } + + pp := (*RawSockaddrCAN)(unsafe.Pointer(rsa)) + + switch proto { + case CAN_J1939: + sa := &SockaddrCANJ1939{ + Ifindex: int(pp.Ifindex), + } + name := (*[8]byte)(unsafe.Pointer(&sa.Name)) + for i := 0; i < 8; i++ { + name[i] = pp.Addr[i] + } + pgn := (*[4]byte)(unsafe.Pointer(&sa.PGN)) + for i := 0; i < 4; i++ { + pgn[i] = pp.Addr[i+8] + } + addr := (*[1]byte)(unsafe.Pointer(&sa.Addr)) + addr[0] = pp.Addr[12] + return sa, nil + default: + sa := &SockaddrCAN{ + Ifindex: int(pp.Ifindex), + } + rx := (*[4]byte)(unsafe.Pointer(&sa.RxID)) + for i := 0; i < 4; i++ { + rx[i] = pp.Addr[i] + } + tx := (*[4]byte)(unsafe.Pointer(&sa.TxID)) + for i := 0; i < 4; i++ { + tx[i] = pp.Addr[i+4] + } + return sa, nil + } + case AF_NFC: + proto, err := socketProtocol(fd) + if err != nil { + return nil, err + } + switch proto { + case NFC_SOCKPROTO_RAW: + pp := (*RawSockaddrNFC)(unsafe.Pointer(rsa)) + sa := &SockaddrNFC{ + DeviceIdx: pp.Dev_idx, + TargetIdx: pp.Target_idx, + NFCProtocol: pp.Nfc_protocol, + } + return sa, nil + case NFC_SOCKPROTO_LLCP: + pp := (*RawSockaddrNFCLLCP)(unsafe.Pointer(rsa)) + if uint64(pp.Service_name_len) > uint64(len(pp.Service_name)) { + return nil, EINVAL + } + sa := &SockaddrNFCLLCP{ + DeviceIdx: pp.Dev_idx, + TargetIdx: pp.Target_idx, + NFCProtocol: pp.Nfc_protocol, + DestinationSAP: pp.Dsap, + SourceSAP: pp.Ssap, + ServiceName: string(pp.Service_name[:pp.Service_name_len]), + } + return sa, nil + default: + return nil, EINVAL + } + } + return nil, EAFNOSUPPORT +} + +func Accept(fd int) (nfd int, sa Sockaddr, err error) { + var rsa RawSockaddrAny + var len _Socklen = SizeofSockaddrAny + nfd, err = accept4(fd, &rsa, &len, 0) + if err != nil { + return + } + sa, err = anyToSockaddr(fd, &rsa) + if err != nil { + Close(nfd) + nfd = 0 + } + return +} + +func Accept4(fd int, flags int) (nfd int, sa Sockaddr, err error) { + var rsa RawSockaddrAny + var len _Socklen = SizeofSockaddrAny + nfd, err = accept4(fd, &rsa, &len, flags) + if err != nil { + return + } + if len > SizeofSockaddrAny { + panic("RawSockaddrAny too small") + } + sa, err = anyToSockaddr(fd, &rsa) + if err != nil { + Close(nfd) + nfd = 0 + } + return +} + +func Getsockname(fd int) (sa Sockaddr, err error) { + var rsa RawSockaddrAny + var len _Socklen = SizeofSockaddrAny + if err = getsockname(fd, &rsa, &len); err != nil { + return + } + return anyToSockaddr(fd, &rsa) +} + +func GetsockoptIPMreqn(fd, level, opt int) (*IPMreqn, error) { + var value IPMreqn + vallen := _Socklen(SizeofIPMreqn) + err := getsockopt(fd, level, opt, unsafe.Pointer(&value), &vallen) + return &value, err +} + +func GetsockoptUcred(fd, level, opt int) (*Ucred, error) { + var value Ucred + vallen := _Socklen(SizeofUcred) + err := getsockopt(fd, level, opt, unsafe.Pointer(&value), &vallen) + return &value, err +} + +func GetsockoptTCPInfo(fd, 
level, opt int) (*TCPInfo, error) { + var value TCPInfo + vallen := _Socklen(SizeofTCPInfo) + err := getsockopt(fd, level, opt, unsafe.Pointer(&value), &vallen) + return &value, err +} + +// GetsockoptString returns the string value of the socket option opt for the +// socket associated with fd at the given socket level. +func GetsockoptString(fd, level, opt int) (string, error) { + buf := make([]byte, 256) + vallen := _Socklen(len(buf)) + err := getsockopt(fd, level, opt, unsafe.Pointer(&buf[0]), &vallen) + if err != nil { + if err == ERANGE { + buf = make([]byte, vallen) + err = getsockopt(fd, level, opt, unsafe.Pointer(&buf[0]), &vallen) + } + if err != nil { + return "", err + } + } + return string(buf[:vallen-1]), nil +} + +func GetsockoptTpacketStats(fd, level, opt int) (*TpacketStats, error) { + var value TpacketStats + vallen := _Socklen(SizeofTpacketStats) + err := getsockopt(fd, level, opt, unsafe.Pointer(&value), &vallen) + return &value, err +} + +func GetsockoptTpacketStatsV3(fd, level, opt int) (*TpacketStatsV3, error) { + var value TpacketStatsV3 + vallen := _Socklen(SizeofTpacketStatsV3) + err := getsockopt(fd, level, opt, unsafe.Pointer(&value), &vallen) + return &value, err +} + +func SetsockoptIPMreqn(fd, level, opt int, mreq *IPMreqn) (err error) { + return setsockopt(fd, level, opt, unsafe.Pointer(mreq), unsafe.Sizeof(*mreq)) +} + +func SetsockoptPacketMreq(fd, level, opt int, mreq *PacketMreq) error { + return setsockopt(fd, level, opt, unsafe.Pointer(mreq), unsafe.Sizeof(*mreq)) +} + +// SetsockoptSockFprog attaches a classic BPF or an extended BPF program to a +// socket to filter incoming packets. See 'man 7 socket' for usage information. +func SetsockoptSockFprog(fd, level, opt int, fprog *SockFprog) error { + return setsockopt(fd, level, opt, unsafe.Pointer(fprog), unsafe.Sizeof(*fprog)) +} + +func SetsockoptCanRawFilter(fd, level, opt int, filter []CanFilter) error { + var p unsafe.Pointer + if len(filter) > 0 { + p = unsafe.Pointer(&filter[0]) + } + return setsockopt(fd, level, opt, p, uintptr(len(filter)*SizeofCanFilter)) +} + +func SetsockoptTpacketReq(fd, level, opt int, tp *TpacketReq) error { + return setsockopt(fd, level, opt, unsafe.Pointer(tp), unsafe.Sizeof(*tp)) +} + +func SetsockoptTpacketReq3(fd, level, opt int, tp *TpacketReq3) error { + return setsockopt(fd, level, opt, unsafe.Pointer(tp), unsafe.Sizeof(*tp)) +} + +func SetsockoptTCPRepairOpt(fd, level, opt int, o []TCPRepairOpt) (err error) { + if len(o) == 0 { + return EINVAL + } + return setsockopt(fd, level, opt, unsafe.Pointer(&o[0]), uintptr(SizeofTCPRepairOpt*len(o))) +} + +func SetsockoptTCPMD5Sig(fd, level, opt int, s *TCPMD5Sig) error { + return setsockopt(fd, level, opt, unsafe.Pointer(s), unsafe.Sizeof(*s)) +} + +// Keyctl Commands (http://man7.org/linux/man-pages/man2/keyctl.2.html) + +// KeyctlInt calls keyctl commands in which each argument is an int. +// These commands are KEYCTL_REVOKE, KEYCTL_CHOWN, KEYCTL_CLEAR, KEYCTL_LINK, +// KEYCTL_UNLINK, KEYCTL_NEGATE, KEYCTL_SET_REQKEY_KEYRING, KEYCTL_SET_TIMEOUT, +// KEYCTL_ASSUME_AUTHORITY, KEYCTL_SESSION_TO_PARENT, KEYCTL_REJECT, +// KEYCTL_INVALIDATE, and KEYCTL_GET_PERSISTENT. +//sys KeyctlInt(cmd int, arg2 int, arg3 int, arg4 int, arg5 int) (ret int, err error) = SYS_KEYCTL + +// KeyctlBuffer calls keyctl commands in which the third and fourth +// arguments are a buffer and its length, respectively. +// These commands are KEYCTL_UPDATE, KEYCTL_READ, and KEYCTL_INSTANTIATE. 
+//sys KeyctlBuffer(cmd int, arg2 int, buf []byte, arg5 int) (ret int, err error) = SYS_KEYCTL
+
+// KeyctlString calls keyctl commands which return a string.
+// These commands are KEYCTL_DESCRIBE and KEYCTL_GET_SECURITY.
+func KeyctlString(cmd int, id int) (string, error) {
+	// We must loop as the string data may change in between the syscalls.
+	// We could allocate a large buffer here to reduce the chance that the
+	// syscall needs to be called twice; however, this is unnecessary as
+	// the performance loss is negligible.
+	var buffer []byte
+	for {
+		// Try to fill the buffer with data
+		length, err := KeyctlBuffer(cmd, id, buffer, 0)
+		if err != nil {
+			return "", err
+		}
+
+		// Check if the data was written
+		if length <= len(buffer) {
+			// Exclude the null terminator
+			return string(buffer[:length-1]), nil
+		}
+
+		// Make a bigger buffer if needed
+		buffer = make([]byte, length)
+	}
+}
+
+// Keyctl commands with special signatures.
+
+// KeyctlGetKeyringID implements the KEYCTL_GET_KEYRING_ID command.
+// See the full documentation at:
+// http://man7.org/linux/man-pages/man3/keyctl_get_keyring_ID.3.html
+func KeyctlGetKeyringID(id int, create bool) (ringid int, err error) {
+	createInt := 0
+	if create {
+		createInt = 1
+	}
+	return KeyctlInt(KEYCTL_GET_KEYRING_ID, id, createInt, 0, 0)
+}
+
+// KeyctlSetperm implements the KEYCTL_SETPERM command. The perm value is the
+// key handle permission mask as described in the "keyctl setperm" section of
+// http://man7.org/linux/man-pages/man1/keyctl.1.html.
+// See the full documentation at:
+// http://man7.org/linux/man-pages/man3/keyctl_setperm.3.html
+func KeyctlSetperm(id int, perm uint32) error {
+	_, err := KeyctlInt(KEYCTL_SETPERM, id, int(perm), 0, 0)
+	return err
+}
+
+//sys keyctlJoin(cmd int, arg2 string) (ret int, err error) = SYS_KEYCTL
+
+// KeyctlJoinSessionKeyring implements the KEYCTL_JOIN_SESSION_KEYRING command.
+// See the full documentation at:
+// http://man7.org/linux/man-pages/man3/keyctl_join_session_keyring.3.html
+func KeyctlJoinSessionKeyring(name string) (ringid int, err error) {
+	return keyctlJoin(KEYCTL_JOIN_SESSION_KEYRING, name)
+}
+
+//sys keyctlSearch(cmd int, arg2 int, arg3 string, arg4 string, arg5 int) (ret int, err error) = SYS_KEYCTL
+
+// KeyctlSearch implements the KEYCTL_SEARCH command.
+// See the full documentation at:
+// http://man7.org/linux/man-pages/man3/keyctl_search.3.html
+func KeyctlSearch(ringid int, keyType, description string, destRingid int) (id int, err error) {
+	return keyctlSearch(KEYCTL_SEARCH, ringid, keyType, description, destRingid)
+}
+
+//sys keyctlIOV(cmd int, arg2 int, payload []Iovec, arg5 int) (err error) = SYS_KEYCTL
+
+// KeyctlInstantiateIOV implements the KEYCTL_INSTANTIATE_IOV command. This
+// command is similar to KEYCTL_INSTANTIATE, except that the payload is a slice
+// of Iovec (each of which represents a buffer) instead of a single buffer.
+// See the full documentation at:
+// http://man7.org/linux/man-pages/man3/keyctl_instantiate_iov.3.html
+func KeyctlInstantiateIOV(id int, payload []Iovec, ringid int) error {
+	return keyctlIOV(KEYCTL_INSTANTIATE_IOV, id, payload, ringid)
+}
+
+//sys keyctlDH(cmd int, arg2 *KeyctlDHParams, buf []byte) (ret int, err error) = SYS_KEYCTL
+
+// KeyctlDHCompute implements the KEYCTL_DH_COMPUTE command. This command
+// computes a Diffie-Hellman shared secret based on the provided params.
The +// secret is written to the provided buffer and the returned size is the number +// of bytes written (returning an error if there is insufficient space in the +// buffer). If a nil buffer is passed in, this function returns the minimum +// buffer length needed to store the appropriate data. Note that this differs +// from KEYCTL_READ's behavior which always returns the requested payload size. +// See the full documentation at: +// http://man7.org/linux/man-pages/man3/keyctl_dh_compute.3.html +func KeyctlDHCompute(params *KeyctlDHParams, buffer []byte) (size int, err error) { + return keyctlDH(KEYCTL_DH_COMPUTE, params, buffer) +} + +// KeyctlRestrictKeyring implements the KEYCTL_RESTRICT_KEYRING command. This +// command limits the set of keys that can be linked to the keyring, regardless +// of keyring permissions. The command requires the "setattr" permission. +// +// When called with an empty keyType the command locks the keyring, preventing +// any further keys from being linked to the keyring. +// +// The "asymmetric" keyType defines restrictions requiring key payloads to be +// DER encoded X.509 certificates signed by keys in another keyring. Restrictions +// for "asymmetric" include "builtin_trusted", "builtin_and_secondary_trusted", +// "key_or_keyring:", and "key_or_keyring::chain". +// +// As of Linux 4.12, only the "asymmetric" keyType defines type-specific +// restrictions. +// +// See the full documentation at: +// http://man7.org/linux/man-pages/man3/keyctl_restrict_keyring.3.html +// http://man7.org/linux/man-pages/man2/keyctl.2.html +func KeyctlRestrictKeyring(ringid int, keyType string, restriction string) error { + if keyType == "" { + return keyctlRestrictKeyring(KEYCTL_RESTRICT_KEYRING, ringid) + } + return keyctlRestrictKeyringByType(KEYCTL_RESTRICT_KEYRING, ringid, keyType, restriction) +} + +//sys keyctlRestrictKeyringByType(cmd int, arg2 int, keyType string, restriction string) (err error) = SYS_KEYCTL +//sys keyctlRestrictKeyring(cmd int, arg2 int) (err error) = SYS_KEYCTL + +func recvmsgRaw(fd int, iov []Iovec, oob []byte, flags int, rsa *RawSockaddrAny) (n, oobn int, recvflags int, err error) { + var msg Msghdr + msg.Name = (*byte)(unsafe.Pointer(rsa)) + msg.Namelen = uint32(SizeofSockaddrAny) + var dummy byte + if len(oob) > 0 { + if emptyIovecs(iov) { + var sockType int + sockType, err = GetsockoptInt(fd, SOL_SOCKET, SO_TYPE) + if err != nil { + return + } + // receive at least one normal byte + if sockType != SOCK_DGRAM { + var iova [1]Iovec + iova[0].Base = &dummy + iova[0].SetLen(1) + iov = iova[:] + } + } + msg.Control = &oob[0] + msg.SetControllen(len(oob)) + } + if len(iov) > 0 { + msg.Iov = &iov[0] + msg.SetIovlen(len(iov)) + } + if n, err = recvmsg(fd, &msg, flags); err != nil { + return + } + oobn = int(msg.Controllen) + recvflags = int(msg.Flags) + return +} + +func sendmsgN(fd int, iov []Iovec, oob []byte, ptr unsafe.Pointer, salen _Socklen, flags int) (n int, err error) { + var msg Msghdr + msg.Name = (*byte)(ptr) + msg.Namelen = uint32(salen) + var dummy byte + var empty bool + if len(oob) > 0 { + empty = emptyIovecs(iov) + if empty { + var sockType int + sockType, err = GetsockoptInt(fd, SOL_SOCKET, SO_TYPE) + if err != nil { + return 0, err + } + // send at least one normal byte + if sockType != SOCK_DGRAM { + var iova [1]Iovec + iova[0].Base = &dummy + iova[0].SetLen(1) + iov = iova[:] + } + } + msg.Control = &oob[0] + msg.SetControllen(len(oob)) + } + if len(iov) > 0 { + msg.Iov = &iov[0] + msg.SetIovlen(len(iov)) + } + if n, err = 
sendmsg(fd, &msg, flags); err != nil { + return 0, err + } + if len(oob) > 0 && empty { + n = 0 + } + return n, nil +} + +// BindToDevice binds the socket associated with fd to device. +func BindToDevice(fd int, device string) (err error) { + return SetsockoptString(fd, SOL_SOCKET, SO_BINDTODEVICE, device) +} + +//sys ptrace(request int, pid int, addr uintptr, data uintptr) (err error) +//sys ptracePtr(request int, pid int, addr uintptr, data unsafe.Pointer) (err error) = SYS_PTRACE + +func ptracePeek(req int, pid int, addr uintptr, out []byte) (count int, err error) { + // The peek requests are machine-size oriented, so we wrap it + // to retrieve arbitrary-length data. + + // The ptrace syscall differs from glibc's ptrace. + // Peeks returns the word in *data, not as the return value. + + var buf [SizeofPtr]byte + + // Leading edge. PEEKTEXT/PEEKDATA don't require aligned + // access (PEEKUSER warns that it might), but if we don't + // align our reads, we might straddle an unmapped page + // boundary and not get the bytes leading up to the page + // boundary. + n := 0 + if addr%SizeofPtr != 0 { + err = ptracePtr(req, pid, addr-addr%SizeofPtr, unsafe.Pointer(&buf[0])) + if err != nil { + return 0, err + } + n += copy(out, buf[addr%SizeofPtr:]) + out = out[n:] + } + + // Remainder. + for len(out) > 0 { + // We use an internal buffer to guarantee alignment. + // It's not documented if this is necessary, but we're paranoid. + err = ptracePtr(req, pid, addr+uintptr(n), unsafe.Pointer(&buf[0])) + if err != nil { + return n, err + } + copied := copy(out, buf[0:]) + n += copied + out = out[copied:] + } + + return n, nil +} + +func PtracePeekText(pid int, addr uintptr, out []byte) (count int, err error) { + return ptracePeek(PTRACE_PEEKTEXT, pid, addr, out) +} + +func PtracePeekData(pid int, addr uintptr, out []byte) (count int, err error) { + return ptracePeek(PTRACE_PEEKDATA, pid, addr, out) +} + +func PtracePeekUser(pid int, addr uintptr, out []byte) (count int, err error) { + return ptracePeek(PTRACE_PEEKUSR, pid, addr, out) +} + +func ptracePoke(pokeReq int, peekReq int, pid int, addr uintptr, data []byte) (count int, err error) { + // As for ptracePeek, we need to align our accesses to deal + // with the possibility of straddling an invalid page. + + // Leading edge. + n := 0 + if addr%SizeofPtr != 0 { + var buf [SizeofPtr]byte + err = ptracePtr(peekReq, pid, addr-addr%SizeofPtr, unsafe.Pointer(&buf[0])) + if err != nil { + return 0, err + } + n += copy(buf[addr%SizeofPtr:], data) + word := *((*uintptr)(unsafe.Pointer(&buf[0]))) + err = ptrace(pokeReq, pid, addr-addr%SizeofPtr, word) + if err != nil { + return 0, err + } + data = data[n:] + } + + // Interior. + for len(data) > SizeofPtr { + word := *((*uintptr)(unsafe.Pointer(&data[0]))) + err = ptrace(pokeReq, pid, addr+uintptr(n), word) + if err != nil { + return n, err + } + n += SizeofPtr + data = data[SizeofPtr:] + } + + // Trailing edge. 
+ if len(data) > 0 { + var buf [SizeofPtr]byte + err = ptracePtr(peekReq, pid, addr+uintptr(n), unsafe.Pointer(&buf[0])) + if err != nil { + return n, err + } + copy(buf[0:], data) + word := *((*uintptr)(unsafe.Pointer(&buf[0]))) + err = ptrace(pokeReq, pid, addr+uintptr(n), word) + if err != nil { + return n, err + } + n += len(data) + } + + return n, nil +} + +func PtracePokeText(pid int, addr uintptr, data []byte) (count int, err error) { + return ptracePoke(PTRACE_POKETEXT, PTRACE_PEEKTEXT, pid, addr, data) +} + +func PtracePokeData(pid int, addr uintptr, data []byte) (count int, err error) { + return ptracePoke(PTRACE_POKEDATA, PTRACE_PEEKDATA, pid, addr, data) +} + +func PtracePokeUser(pid int, addr uintptr, data []byte) (count int, err error) { + return ptracePoke(PTRACE_POKEUSR, PTRACE_PEEKUSR, pid, addr, data) +} + +// elfNT_PRSTATUS is a copy of the debug/elf.NT_PRSTATUS constant so +// x/sys/unix doesn't need to depend on debug/elf and thus +// compress/zlib, debug/dwarf, and other packages. +const elfNT_PRSTATUS = 1 + +func PtraceGetRegs(pid int, regsout *PtraceRegs) (err error) { + var iov Iovec + iov.Base = (*byte)(unsafe.Pointer(regsout)) + iov.SetLen(int(unsafe.Sizeof(*regsout))) + return ptracePtr(PTRACE_GETREGSET, pid, uintptr(elfNT_PRSTATUS), unsafe.Pointer(&iov)) +} + +func PtraceSetRegs(pid int, regs *PtraceRegs) (err error) { + var iov Iovec + iov.Base = (*byte)(unsafe.Pointer(regs)) + iov.SetLen(int(unsafe.Sizeof(*regs))) + return ptracePtr(PTRACE_SETREGSET, pid, uintptr(elfNT_PRSTATUS), unsafe.Pointer(&iov)) +} + +func PtraceSetOptions(pid int, options int) (err error) { + return ptrace(PTRACE_SETOPTIONS, pid, 0, uintptr(options)) +} + +func PtraceGetEventMsg(pid int) (msg uint, err error) { + var data _C_long + err = ptracePtr(PTRACE_GETEVENTMSG, pid, 0, unsafe.Pointer(&data)) + msg = uint(data) + return +} + +func PtraceCont(pid int, signal int) (err error) { + return ptrace(PTRACE_CONT, pid, 0, uintptr(signal)) +} + +func PtraceSyscall(pid int, signal int) (err error) { + return ptrace(PTRACE_SYSCALL, pid, 0, uintptr(signal)) +} + +func PtraceSingleStep(pid int) (err error) { return ptrace(PTRACE_SINGLESTEP, pid, 0, 0) } + +func PtraceInterrupt(pid int) (err error) { return ptrace(PTRACE_INTERRUPT, pid, 0, 0) } + +func PtraceAttach(pid int) (err error) { return ptrace(PTRACE_ATTACH, pid, 0, 0) } + +func PtraceSeize(pid int) (err error) { return ptrace(PTRACE_SEIZE, pid, 0, 0) } + +func PtraceDetach(pid int) (err error) { return ptrace(PTRACE_DETACH, pid, 0, 0) } + +//sys reboot(magic1 uint, magic2 uint, cmd int, arg string) (err error) + +func Reboot(cmd int) (err error) { + return reboot(LINUX_REBOOT_MAGIC1, LINUX_REBOOT_MAGIC2, cmd, "") +} + +func direntIno(buf []byte) (uint64, bool) { + return readInt(buf, unsafe.Offsetof(Dirent{}.Ino), unsafe.Sizeof(Dirent{}.Ino)) +} + +func direntReclen(buf []byte) (uint64, bool) { + return readInt(buf, unsafe.Offsetof(Dirent{}.Reclen), unsafe.Sizeof(Dirent{}.Reclen)) +} + +func direntNamlen(buf []byte) (uint64, bool) { + reclen, ok := direntReclen(buf) + if !ok { + return 0, false + } + return reclen - uint64(unsafe.Offsetof(Dirent{}.Name)), true +} + +//sys mount(source string, target string, fstype string, flags uintptr, data *byte) (err error) + +func Mount(source string, target string, fstype string, flags uintptr, data string) (err error) { + // Certain file systems get rather angry and EINVAL if you give + // them an empty string of data, rather than NULL. 
+ if data == "" { + return mount(source, target, fstype, flags, nil) + } + datap, err := BytePtrFromString(data) + if err != nil { + return err + } + return mount(source, target, fstype, flags, datap) +} + +//sys mountSetattr(dirfd int, pathname string, flags uint, attr *MountAttr, size uintptr) (err error) = SYS_MOUNT_SETATTR + +// MountSetattr is a wrapper for mount_setattr(2). +// https://man7.org/linux/man-pages/man2/mount_setattr.2.html +// +// Requires kernel >= 5.12. +func MountSetattr(dirfd int, pathname string, flags uint, attr *MountAttr) error { + return mountSetattr(dirfd, pathname, flags, attr, unsafe.Sizeof(*attr)) +} + +func Sendfile(outfd int, infd int, offset *int64, count int) (written int, err error) { + if raceenabled { + raceReleaseMerge(unsafe.Pointer(&ioSync)) + } + return sendfile(outfd, infd, offset, count) +} + +// Sendto +// Recvfrom +// Socketpair + +/* + * Direct access + */ +//sys Acct(path string) (err error) +//sys AddKey(keyType string, description string, payload []byte, ringid int) (id int, err error) +//sys Adjtimex(buf *Timex) (state int, err error) +//sysnb Capget(hdr *CapUserHeader, data *CapUserData) (err error) +//sysnb Capset(hdr *CapUserHeader, data *CapUserData) (err error) +//sys Chdir(path string) (err error) +//sys Chroot(path string) (err error) +//sys ClockAdjtime(clockid int32, buf *Timex) (state int, err error) +//sys ClockGetres(clockid int32, res *Timespec) (err error) +//sys ClockGettime(clockid int32, time *Timespec) (err error) +//sys ClockNanosleep(clockid int32, flags int, request *Timespec, remain *Timespec) (err error) +//sys Close(fd int) (err error) +//sys CloseRange(first uint, last uint, flags uint) (err error) +//sys CopyFileRange(rfd int, roff *int64, wfd int, woff *int64, len int, flags int) (n int, err error) +//sys DeleteModule(name string, flags int) (err error) +//sys Dup(oldfd int) (fd int, err error) + +func Dup2(oldfd, newfd int) error { + return Dup3(oldfd, newfd, 0) +} + +//sys Dup3(oldfd int, newfd int, flags int) (err error) +//sysnb EpollCreate1(flag int) (fd int, err error) +//sysnb EpollCtl(epfd int, op int, fd int, event *EpollEvent) (err error) +//sys Eventfd(initval uint, flags int) (fd int, err error) = SYS_EVENTFD2 +//sys Exit(code int) = SYS_EXIT_GROUP +//sys Fallocate(fd int, mode uint32, off int64, len int64) (err error) +//sys Fchdir(fd int) (err error) +//sys Fchmod(fd int, mode uint32) (err error) +//sys Fchownat(dirfd int, path string, uid int, gid int, flags int) (err error) +//sys Fdatasync(fd int) (err error) +//sys Fgetxattr(fd int, attr string, dest []byte) (sz int, err error) +//sys FinitModule(fd int, params string, flags int) (err error) +//sys Flistxattr(fd int, dest []byte) (sz int, err error) +//sys Flock(fd int, how int) (err error) +//sys Fremovexattr(fd int, attr string) (err error) +//sys Fsetxattr(fd int, attr string, dest []byte, flags int) (err error) +//sys Fsync(fd int) (err error) +//sys Fsmount(fd int, flags int, mountAttrs int) (fsfd int, err error) +//sys Fsopen(fsName string, flags int) (fd int, err error) +//sys Fspick(dirfd int, pathName string, flags int) (fd int, err error) +//sys Getdents(fd int, buf []byte) (n int, err error) = SYS_GETDENTS64 +//sysnb Getpgid(pid int) (pgid int, err error) + +func Getpgrp() (pid int) { + pid, _ = Getpgid(0) + return +} + +//sysnb Getpid() (pid int) +//sysnb Getppid() (ppid int) +//sys Getpriority(which int, who int) (prio int, err error) +//sys Getrandom(buf []byte, flags int) (n int, err error) +//sysnb Getrusage(who int, rusage 
*Rusage) (err error) +//sysnb Getsid(pid int) (sid int, err error) +//sysnb Gettid() (tid int) +//sys Getxattr(path string, attr string, dest []byte) (sz int, err error) +//sys InitModule(moduleImage []byte, params string) (err error) +//sys InotifyAddWatch(fd int, pathname string, mask uint32) (watchdesc int, err error) +//sysnb InotifyInit1(flags int) (fd int, err error) +//sysnb InotifyRmWatch(fd int, watchdesc uint32) (success int, err error) +//sysnb Kill(pid int, sig syscall.Signal) (err error) +//sys Klogctl(typ int, buf []byte) (n int, err error) = SYS_SYSLOG +//sys Lgetxattr(path string, attr string, dest []byte) (sz int, err error) +//sys Listxattr(path string, dest []byte) (sz int, err error) +//sys Llistxattr(path string, dest []byte) (sz int, err error) +//sys Lremovexattr(path string, attr string) (err error) +//sys Lsetxattr(path string, attr string, data []byte, flags int) (err error) +//sys MemfdCreate(name string, flags int) (fd int, err error) +//sys Mkdirat(dirfd int, path string, mode uint32) (err error) +//sys Mknodat(dirfd int, path string, mode uint32, dev int) (err error) +//sys MoveMount(fromDirfd int, fromPathName string, toDirfd int, toPathName string, flags int) (err error) +//sys Nanosleep(time *Timespec, leftover *Timespec) (err error) +//sys OpenTree(dfd int, fileName string, flags uint) (r int, err error) +//sys PerfEventOpen(attr *PerfEventAttr, pid int, cpu int, groupFd int, flags int) (fd int, err error) +//sys PivotRoot(newroot string, putold string) (err error) = SYS_PIVOT_ROOT +//sys Prctl(option int, arg2 uintptr, arg3 uintptr, arg4 uintptr, arg5 uintptr) (err error) +//sys pselect6(nfd int, r *FdSet, w *FdSet, e *FdSet, timeout *Timespec, sigmask *sigset_argpack) (n int, err error) +//sys read(fd int, p []byte) (n int, err error) +//sys Removexattr(path string, attr string) (err error) +//sys Renameat2(olddirfd int, oldpath string, newdirfd int, newpath string, flags uint) (err error) +//sys RequestKey(keyType string, description string, callback string, destRingid int) (id int, err error) +//sys Setdomainname(p []byte) (err error) +//sys Sethostname(p []byte) (err error) +//sysnb Setpgid(pid int, pgid int) (err error) +//sysnb Setsid() (pid int, err error) +//sysnb Settimeofday(tv *Timeval) (err error) +//sys Setns(fd int, nstype int) (err error) + +//go:linkname syscall_prlimit syscall.prlimit +func syscall_prlimit(pid, resource int, newlimit, old *syscall.Rlimit) error + +func Prlimit(pid, resource int, newlimit, old *Rlimit) error { + // Just call the syscall version, because as of Go 1.21 + // it will affect starting a new process. + return syscall_prlimit(pid, resource, (*syscall.Rlimit)(newlimit), (*syscall.Rlimit)(old)) +} + +// PrctlRetInt performs a prctl operation specified by option and further +// optional arguments arg2 through arg5 depending on option. It returns a +// non-negative integer that is returned by the prctl syscall. 
+func PrctlRetInt(option int, arg2 uintptr, arg3 uintptr, arg4 uintptr, arg5 uintptr) (int, error) {
+	ret, _, err := Syscall6(SYS_PRCTL, uintptr(option), uintptr(arg2), uintptr(arg3), uintptr(arg4), uintptr(arg5), 0)
+	if err != 0 {
+		return 0, err
+	}
+	return int(ret), nil
+}
+
+func Setuid(uid int) (err error) {
+	return syscall.Setuid(uid)
+}
+
+func Setgid(gid int) (err error) {
+	return syscall.Setgid(gid)
+}
+
+func Setreuid(ruid, euid int) (err error) {
+	return syscall.Setreuid(ruid, euid)
+}
+
+func Setregid(rgid, egid int) (err error) {
+	return syscall.Setregid(rgid, egid)
+}
+
+func Setresuid(ruid, euid, suid int) (err error) {
+	return syscall.Setresuid(ruid, euid, suid)
+}
+
+func Setresgid(rgid, egid, sgid int) (err error) {
+	return syscall.Setresgid(rgid, egid, sgid)
+}
+
+// SetfsgidRetGid sets fsgid for current thread and returns previous fsgid set.
+// setfsgid(2) will return a non-nil error only if its caller lacks the CAP_SETGID capability.
+// If the call fails due to other reasons, current fsgid will be returned.
+func SetfsgidRetGid(gid int) (int, error) {
+	return setfsgid(gid)
+}
+
+// SetfsuidRetUid sets fsuid for current thread and returns previous fsuid set.
+// setfsuid(2) will return a non-nil error only if its caller lacks the CAP_SETUID capability.
+// If the call fails due to other reasons, current fsuid will be returned.
+func SetfsuidRetUid(uid int) (int, error) {
+	return setfsuid(uid)
+}
+
+func Setfsgid(gid int) error {
+	_, err := setfsgid(gid)
+	return err
+}
+
+func Setfsuid(uid int) error {
+	_, err := setfsuid(uid)
+	return err
+}
+
+func Signalfd(fd int, sigmask *Sigset_t, flags int) (newfd int, err error) {
+	return signalfd(fd, sigmask, _C__NSIG/8, flags)
+}
+
+//sys Setpriority(which int, who int, prio int) (err error)
+//sys Setxattr(path string, attr string, data []byte, flags int) (err error)
+//sys signalfd(fd int, sigmask *Sigset_t, maskSize uintptr, flags int) (newfd int, err error) = SYS_SIGNALFD4
+//sys Statx(dirfd int, path string, flags int, mask int, stat *Statx_t) (err error)
+//sys Sync()
+//sys Syncfs(fd int) (err error)
+//sysnb Sysinfo(info *Sysinfo_t) (err error)
+//sys Tee(rfd int, wfd int, len int, flags int) (n int64, err error)
+//sysnb TimerfdCreate(clockid int, flags int) (fd int, err error)
+//sysnb TimerfdGettime(fd int, currValue *ItimerSpec) (err error)
+//sysnb TimerfdSettime(fd int, flags int, newValue *ItimerSpec, oldValue *ItimerSpec) (err error)
+//sysnb Tgkill(tgid int, tid int, sig syscall.Signal) (err error)
+//sysnb Times(tms *Tms) (ticks uintptr, err error)
+//sysnb Umask(mask int) (oldmask int)
+//sysnb Uname(buf *Utsname) (err error)
+//sys Unmount(target string, flags int) (err error) = SYS_UMOUNT2
+//sys Unshare(flags int) (err error)
+//sys write(fd int, p []byte) (n int, err error)
+//sys exitThread(code int) (err error) = SYS_EXIT
+//sys readv(fd int, iovs []Iovec) (n int, err error) = SYS_READV
+//sys writev(fd int, iovs []Iovec) (n int, err error) = SYS_WRITEV
+//sys preadv(fd int, iovs []Iovec, offs_l uintptr, offs_h uintptr) (n int, err error) = SYS_PREADV
+//sys pwritev(fd int, iovs []Iovec, offs_l uintptr, offs_h uintptr) (n int, err error) = SYS_PWRITEV
+//sys preadv2(fd int, iovs []Iovec, offs_l uintptr, offs_h uintptr, flags int) (n int, err error) = SYS_PREADV2
+//sys pwritev2(fd int, iovs []Iovec, offs_l uintptr, offs_h uintptr, flags int) (n int, err error) = SYS_PWRITEV2
+
+// minIovec is the size of the small initial allocation used by
+// Readv, Writev, etc.
+//
+// This small allocation gets stack allocated, which lets the
+// common use case of len(iovs) <= minIovec avoid more expensive
+// heap allocations.
+const minIovec = 8
+
+// appendBytes converts bs to Iovecs and appends them to vecs.
+func appendBytes(vecs []Iovec, bs [][]byte) []Iovec {
+	for _, b := range bs {
+		var v Iovec
+		v.SetLen(len(b))
+		if len(b) > 0 {
+			v.Base = &b[0]
+		} else {
+			v.Base = (*byte)(unsafe.Pointer(&_zero))
+		}
+		vecs = append(vecs, v)
+	}
+	return vecs
+}
+
+// offs2lohi splits offs into its low and high order bits.
+func offs2lohi(offs int64) (lo, hi uintptr) {
+	const longBits = SizeofLong * 8
+	return uintptr(offs), uintptr(uint64(offs) >> (longBits - 1) >> 1) // two shifts to avoid false positive in vet
+}
+
+func Readv(fd int, iovs [][]byte) (n int, err error) {
+	iovecs := make([]Iovec, 0, minIovec)
+	iovecs = appendBytes(iovecs, iovs)
+	n, err = readv(fd, iovecs)
+	readvRacedetect(iovecs, n, err)
+	return n, err
+}
+
+func Preadv(fd int, iovs [][]byte, offset int64) (n int, err error) {
+	iovecs := make([]Iovec, 0, minIovec)
+	iovecs = appendBytes(iovecs, iovs)
+	lo, hi := offs2lohi(offset)
+	n, err = preadv(fd, iovecs, lo, hi)
+	readvRacedetect(iovecs, n, err)
+	return n, err
+}
+
+func Preadv2(fd int, iovs [][]byte, offset int64, flags int) (n int, err error) {
+	iovecs := make([]Iovec, 0, minIovec)
+	iovecs = appendBytes(iovecs, iovs)
+	lo, hi := offs2lohi(offset)
+	n, err = preadv2(fd, iovecs, lo, hi, flags)
+	readvRacedetect(iovecs, n, err)
+	return n, err
+}
+
+func readvRacedetect(iovecs []Iovec, n int, err error) {
+	if !raceenabled {
+		return
+	}
+	for i := 0; n > 0 && i < len(iovecs); i++ {
+		m := int(iovecs[i].Len)
+		if m > n {
+			m = n
+		}
+		n -= m
+		if m > 0 {
+			raceWriteRange(unsafe.Pointer(iovecs[i].Base), m)
+		}
+	}
+	if err == nil {
+		raceAcquire(unsafe.Pointer(&ioSync))
+	}
+}
+
+func Writev(fd int, iovs [][]byte) (n int, err error) {
+	iovecs := make([]Iovec, 0, minIovec)
+	iovecs = appendBytes(iovecs, iovs)
+	if raceenabled {
+		raceReleaseMerge(unsafe.Pointer(&ioSync))
+	}
+	n, err = writev(fd, iovecs)
+	writevRacedetect(iovecs, n)
+	return n, err
+}
+
+func Pwritev(fd int, iovs [][]byte, offset int64) (n int, err error) {
+	iovecs := make([]Iovec, 0, minIovec)
+	iovecs = appendBytes(iovecs, iovs)
+	if raceenabled {
+		raceReleaseMerge(unsafe.Pointer(&ioSync))
+	}
+	lo, hi := offs2lohi(offset)
+	n, err = pwritev(fd, iovecs, lo, hi)
+	writevRacedetect(iovecs, n)
+	return n, err
+}
+
+func Pwritev2(fd int, iovs [][]byte, offset int64, flags int) (n int, err error) {
+	iovecs := make([]Iovec, 0, minIovec)
+	iovecs = appendBytes(iovecs, iovs)
+	if raceenabled {
+		raceReleaseMerge(unsafe.Pointer(&ioSync))
+	}
+	lo, hi := offs2lohi(offset)
+	n, err = pwritev2(fd, iovecs, lo, hi, flags)
+	writevRacedetect(iovecs, n)
+	return n, err
+}
+
+func writevRacedetect(iovecs []Iovec, n int) {
+	if !raceenabled {
+		return
+	}
+	for i := 0; n > 0 && i < len(iovecs); i++ {
+		m := int(iovecs[i].Len)
+		if m > n {
+			m = n
+		}
+		n -= m
+		if m > 0 {
+			raceReadRange(unsafe.Pointer(iovecs[i].Base), m)
+		}
+	}
+}
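+
+// A short usage sketch for the vectored-I/O wrappers above, assuming fd is
+// an open file or socket descriptor: Writev gathers several byte slices
+// into a single write-family system call.
+//
+//	bufs := [][]byte{[]byte("hello, "), []byte("world\n")}
+//	n, err := unix.Writev(fd, bufs)
+//	// n counts the bytes written across all of bufs.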
+//sys munmap(addr uintptr, length uintptr) (err error)
+//sys mremap(oldaddr uintptr, oldlength uintptr, newlength uintptr, flags int, newaddr uintptr) (xaddr uintptr, err error)
+//sys Madvise(b []byte, advice int) (err error)
+//sys Mprotect(b []byte, prot int) (err error)
+//sys Mlock(b []byte) (err error)
+//sys Mlockall(flags int) (err error)
+//sys Msync(b []byte, flags int) (err error)
+//sys Munlock(b []byte) (err error)
+//sys Munlockall() (err error)
+
+const (
+ mremapFixed = MREMAP_FIXED
+ mremapDontunmap = MREMAP_DONTUNMAP
+ mremapMaymove = MREMAP_MAYMOVE
+)
+
+// Vmsplice splices user pages from a slice of Iovecs into a pipe specified by fd,
+// using the specified flags.
+func Vmsplice(fd int, iovs []Iovec, flags int) (int, error) {
+ var p unsafe.Pointer
+ if len(iovs) > 0 {
+ p = unsafe.Pointer(&iovs[0])
+ }
+
+ n, _, errno := Syscall6(SYS_VMSPLICE, uintptr(fd), uintptr(p), uintptr(len(iovs)), uintptr(flags), 0, 0)
+ if errno != 0 {
+ return 0, syscall.Errno(errno)
+ }
+
+ return int(n), nil
+}
+
+func isGroupMember(gid int) bool {
+ groups, err := Getgroups()
+ if err != nil {
+ return false
+ }
+
+ for _, g := range groups {
+ if g == gid {
+ return true
+ }
+ }
+ return false
+}
+
+func isCapDacOverrideSet() bool {
+ hdr := CapUserHeader{Version: LINUX_CAPABILITY_VERSION_3}
+ data := [2]CapUserData{}
+ err := Capget(&hdr, &data[0])
+
+ return err == nil && data[0].Effective&(1<<CAP_DAC_OVERRIDE) != 0
+}
+
+//sys faccessat(dirfd int, path string, mode uint32) (err error)
+//sys faccessat2(dirfd int, path string, mode uint32, flags int) (err error) = SYS_FACCESSAT2
+
+// Faccessat emulates the flags argument in user space on kernels whose
+// faccessat takes no flags, falling back from faccessat2 as glibc does.
+func Faccessat(dirfd int, path string, mode uint32, flags int) (err error) {
+ if flags == 0 {
+ return faccessat(dirfd, path, mode)
+ }
+
+ if err := faccessat2(dirfd, path, mode, flags); err != ENOSYS && err != EPERM {
+ return err
+ }
+
+ // The Linux kernel faccessat system call does not take any flags.
+ // The glibc faccessat implements the flags itself; see
+ // https://sourceware.org/git/?p=glibc.git;a=blob;f=sysdeps/unix/sysv/linux/faccessat.c;hb=HEAD
+ // Because people naturally expect syscall.Faccessat to act
+ // like C faccessat, we do the same.
+
+ if flags & ^(AT_SYMLINK_NOFOLLOW|AT_EACCESS) != 0 {
+ return EINVAL
+ }
+
+ var st Stat_t
+ if err := Fstatat(dirfd, path, &st, flags&AT_SYMLINK_NOFOLLOW); err != nil {
+ return err
+ }
+
+ mode &= 7
+ if mode == 0 {
+ return nil
+ }
+
+ var uid int
+ if flags&AT_EACCESS != 0 {
+ uid = Geteuid()
+ if uid != 0 && isCapDacOverrideSet() {
+ // If CAP_DAC_OVERRIDE is set, file access check is
+ // done by the kernel in the same way as for root
+ // (see generic_permission() in the Linux sources).
+ uid = 0
+ }
+ } else {
+ uid = Getuid()
+ }
+
+ if uid == 0 {
+ return nil // Root can access anything.
+ }
+
+ var fmode uint32
+ if uint32(uid) == st.Uid {
+ fmode = (st.Mode >> 6) & 7
+ } else {
+ var gid int
+ if flags&AT_EACCESS != 0 {
+ gid = Getegid()
+ } else {
+ gid = Getgid()
+ }
+
+ if uint32(gid) == st.Gid || isGroupMember(int(st.Gid)) {
+ fmode = (st.Mode >> 3) & 7
+ } else {
+ fmode = st.Mode & 7
+ }
+ }
+
+ if fmode&mode == mode {
+ return nil
+ }
+
+ return EACCES
+}
+
+//sys nameToHandleAt(dirFD int, pathname string, fh *fileHandle, mountID *_C_int, flags int) (err error) = SYS_NAME_TO_HANDLE_AT
+//sys openByHandleAt(mountFD int, fh *fileHandle, flags int) (fd int, err error) = SYS_OPEN_BY_HANDLE_AT
+
+// fileHandle is the argument to nameToHandleAt and openByHandleAt. We
+// originally tried to generate it via unix/linux/types.go with "type
+// fileHandle C.struct_file_handle" but that generated empty structs
+// for mips64 and mips64le. Instead, hard code it for now (it's the
+// same everywhere else) until the mips64 generator issue is fixed.
+type fileHandle struct {
+ Bytes uint32
+ Type int32
+}
+
+// FileHandle represents the C struct file_handle used by
+// name_to_handle_at (see NameToHandleAt) and open_by_handle_at (see
+// OpenByHandleAt).
+type FileHandle struct {
+ *fileHandle
+}
+
+// NewFileHandle constructs a FileHandle.
+func NewFileHandle(handleType int32, handle []byte) FileHandle {
+ const hdrSize = unsafe.Sizeof(fileHandle{})
+ buf := make([]byte, hdrSize+uintptr(len(handle)))
+ copy(buf[hdrSize:], handle)
+ fh := (*fileHandle)(unsafe.Pointer(&buf[0]))
+ fh.Type = handleType
+ fh.Bytes = uint32(len(handle))
+ return FileHandle{fh}
+}
+
+func (fh *FileHandle) Size() int { return int(fh.fileHandle.Bytes) }
+func (fh *FileHandle) Type() int32 { return fh.fileHandle.Type }
+func (fh *FileHandle) Bytes() []byte {
+ n := fh.Size()
+ if n == 0 {
+ return nil
+ }
+ return unsafe.Slice((*byte)(unsafe.Pointer(uintptr(unsafe.Pointer(&fh.fileHandle.Type))+4)), n)
+}
+
+// NameToHandleAt wraps the name_to_handle_at system call; it obtains
+// a handle for a path name.
+func NameToHandleAt(dirfd int, path string, flags int) (handle FileHandle, mountID int, err error) { + var mid _C_int + // Try first with a small buffer, assuming the handle will + // only be 32 bytes. + size := uint32(32 + unsafe.Sizeof(fileHandle{})) + didResize := false + for { + buf := make([]byte, size) + fh := (*fileHandle)(unsafe.Pointer(&buf[0])) + fh.Bytes = size - uint32(unsafe.Sizeof(fileHandle{})) + err = nameToHandleAt(dirfd, path, fh, &mid, flags) + if err == EOVERFLOW { + if didResize { + // We shouldn't need to resize more than once + return + } + didResize = true + size = fh.Bytes + uint32(unsafe.Sizeof(fileHandle{})) + continue + } + if err != nil { + return + } + return FileHandle{fh}, int(mid), nil + } +} + +// OpenByHandleAt wraps the open_by_handle_at system call; it opens a +// file via a handle as previously returned by NameToHandleAt. +func OpenByHandleAt(mountFD int, handle FileHandle, flags int) (fd int, err error) { + return openByHandleAt(mountFD, handle.fileHandle, flags) +} + +// Klogset wraps the sys_syslog system call; it sets console_loglevel to +// the value specified by arg and passes a dummy pointer to bufp. +func Klogset(typ int, arg int) (err error) { + var p unsafe.Pointer + _, _, errno := Syscall(SYS_SYSLOG, uintptr(typ), uintptr(p), uintptr(arg)) + if errno != 0 { + return errnoErr(errno) + } + return nil +} + +// RemoteIovec is Iovec with the pointer replaced with an integer. +// It is used for ProcessVMReadv and ProcessVMWritev, where the pointer +// refers to a location in a different process' address space, which +// would confuse the Go garbage collector. +type RemoteIovec struct { + Base uintptr + Len int +} + +//sys ProcessVMReadv(pid int, localIov []Iovec, remoteIov []RemoteIovec, flags uint) (n int, err error) = SYS_PROCESS_VM_READV +//sys ProcessVMWritev(pid int, localIov []Iovec, remoteIov []RemoteIovec, flags uint) (n int, err error) = SYS_PROCESS_VM_WRITEV + +//sys PidfdOpen(pid int, flags int) (fd int, err error) = SYS_PIDFD_OPEN +//sys PidfdGetfd(pidfd int, targetfd int, flags int) (fd int, err error) = SYS_PIDFD_GETFD +//sys PidfdSendSignal(pidfd int, sig Signal, info *Siginfo, flags int) (err error) = SYS_PIDFD_SEND_SIGNAL + +//sys shmat(id int, addr uintptr, flag int) (ret uintptr, err error) +//sys shmctl(id int, cmd int, buf *SysvShmDesc) (result int, err error) +//sys shmdt(addr uintptr) (err error) +//sys shmget(key int, size int, flag int) (id int, err error) + +//sys getitimer(which int, currValue *Itimerval) (err error) +//sys setitimer(which int, newValue *Itimerval, oldValue *Itimerval) (err error) + +// MakeItimerval creates an Itimerval from interval and value durations. +func MakeItimerval(interval, value time.Duration) Itimerval { + return Itimerval{ + Interval: NsecToTimeval(interval.Nanoseconds()), + Value: NsecToTimeval(value.Nanoseconds()), + } +} + +// A value which may be passed to the which parameter for Getitimer and +// Setitimer. +type ItimerWhich int + +// Possible which values for Getitimer and Setitimer. +const ( + ItimerReal ItimerWhich = ITIMER_REAL + ItimerVirtual ItimerWhich = ITIMER_VIRTUAL + ItimerProf ItimerWhich = ITIMER_PROF +) + +// Getitimer wraps getitimer(2) to return the current value of the timer +// specified by which. +func Getitimer(which ItimerWhich) (Itimerval, error) { + var it Itimerval + if err := getitimer(int(which), &it); err != nil { + return Itimerval{}, err + } + + return it, nil +} + +// Setitimer wraps setitimer(2) to arm or disarm the timer specified by which. 
+// It returns the previous value of the timer.
+//
+// If the Itimerval argument is the zero value, the timer will be disarmed.
+func Setitimer(which ItimerWhich, it Itimerval) (Itimerval, error) {
+ var prev Itimerval
+ if err := setitimer(int(which), &it, &prev); err != nil {
+ return Itimerval{}, err
+ }
+
+ return prev, nil
+}
+
+//sysnb rtSigprocmask(how int, set *Sigset_t, oldset *Sigset_t, sigsetsize uintptr) (err error) = SYS_RT_SIGPROCMASK
+
+func PthreadSigmask(how int, set, oldset *Sigset_t) error {
+ if oldset != nil {
+ // Explicitly clear in case Sigset_t is larger than _C__NSIG.
+ *oldset = Sigset_t{}
+ }
+ return rtSigprocmask(how, set, oldset, _C__NSIG/8)
+}
+
+//sysnb getresuid(ruid *_C_int, euid *_C_int, suid *_C_int)
+//sysnb getresgid(rgid *_C_int, egid *_C_int, sgid *_C_int)
+
+func Getresuid() (ruid, euid, suid int) {
+ var r, e, s _C_int
+ getresuid(&r, &e, &s)
+ return int(r), int(e), int(s)
+}
+
+func Getresgid() (rgid, egid, sgid int) {
+ var r, e, s _C_int
+ getresgid(&r, &e, &s)
+ return int(r), int(e), int(s)
+}
+
+// Pselect is a wrapper around the Linux pselect6 system call.
+// This version does not modify the timeout argument.
+func Pselect(nfd int, r *FdSet, w *FdSet, e *FdSet, timeout *Timespec, sigmask *Sigset_t) (n int, err error) {
+ // Per https://man7.org/linux/man-pages/man2/select.2.html#NOTES,
+ // The Linux pselect6() system call modifies its timeout argument.
+ // [Not modifying the argument] is the behavior required by POSIX.1-2001.
+ var mutableTimeout *Timespec
+ if timeout != nil {
+ mutableTimeout = new(Timespec)
+ *mutableTimeout = *timeout
+ }
+
+ // The final argument of the pselect6() system call is not a
+ // sigset_t * pointer, but is instead a structure containing a
+ // pointer to the sigset and its size in bytes.
+ var kernelMask *sigset_argpack
+ if sigmask != nil {
+ wordBits := 32 << (^uintptr(0) >> 63) // see math.intSize
+
+ // A sigset stores one bit per signal,
+ // offset by 1 (because signal 0 does not exist).
+ // So the number of words needed is ⌈(_C__NSIG - 1) / wordBits⌉.
+ sigsetWords := (_C__NSIG - 1 + wordBits - 1) / (wordBits)
+
+ sigsetBytes := uintptr(sigsetWords * (wordBits / 8))
+ kernelMask = &sigset_argpack{
+ ss: sigmask,
+ ssLen: sigsetBytes,
+ }
+ }
+
+ return pselect6(nfd, r, w, e, mutableTimeout, kernelMask)
+}
+
+//sys schedSetattr(pid int, attr *SchedAttr, flags uint) (err error)
+//sys schedGetattr(pid int, attr *SchedAttr, size uint, flags uint) (err error)
+
+// SchedSetAttr is a wrapper for sched_setattr(2) syscall.
+// https://man7.org/linux/man-pages/man2/sched_setattr.2.html
+func SchedSetAttr(pid int, attr *SchedAttr, flags uint) error {
+ if attr == nil {
+ return EINVAL
+ }
+ attr.Size = SizeofSchedAttr
+ return schedSetattr(pid, attr, flags)
+}
+
+// SchedGetAttr is a wrapper for sched_getattr(2) syscall.
+// https://man7.org/linux/man-pages/man2/sched_getattr.2.html
+func SchedGetAttr(pid int, flags uint) (*SchedAttr, error) {
+ attr := &SchedAttr{}
+ if err := schedGetattr(pid, attr, SizeofSchedAttr, flags); err != nil {
+ return nil, err
+ }
+ return attr, nil
+}
diff --git a/vendor/golang.org/x/sys/unix/syscall_netbsd.go b/vendor/golang.org/x/sys/unix/syscall_netbsd.go
new file mode 100644
index 0000000000..88162099af
--- /dev/null
+++ b/vendor/golang.org/x/sys/unix/syscall_netbsd.go
@@ -0,0 +1,371 @@
+// Copyright 2009,2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// NetBSD system calls.
+// This file is compiled as ordinary Go code, +// but it is also input to mksyscall, +// which parses the //sys lines and generates system call stubs. +// Note that sometimes we use a lowercase //sys name and wrap +// it in our own nicer implementation, either here or in +// syscall_bsd.go or syscall_unix.go. + +package unix + +import ( + "syscall" + "unsafe" +) + +// SockaddrDatalink implements the Sockaddr interface for AF_LINK type sockets. +type SockaddrDatalink struct { + Len uint8 + Family uint8 + Index uint16 + Type uint8 + Nlen uint8 + Alen uint8 + Slen uint8 + Data [12]int8 + raw RawSockaddrDatalink +} + +func anyToSockaddrGOOS(fd int, rsa *RawSockaddrAny) (Sockaddr, error) { + return nil, EAFNOSUPPORT +} + +func Syscall9(trap, a1, a2, a3, a4, a5, a6, a7, a8, a9 uintptr) (r1, r2 uintptr, err syscall.Errno) + +func sysctlNodes(mib []_C_int) (nodes []Sysctlnode, err error) { + var olen uintptr + + // Get a list of all sysctl nodes below the given MIB by performing + // a sysctl for the given MIB with CTL_QUERY appended. + mib = append(mib, CTL_QUERY) + qnode := Sysctlnode{Flags: SYSCTL_VERS_1} + qp := (*byte)(unsafe.Pointer(&qnode)) + sz := unsafe.Sizeof(qnode) + if err = sysctl(mib, nil, &olen, qp, sz); err != nil { + return nil, err + } + + // Now that we know the size, get the actual nodes. + nodes = make([]Sysctlnode, olen/sz) + np := (*byte)(unsafe.Pointer(&nodes[0])) + if err = sysctl(mib, np, &olen, qp, sz); err != nil { + return nil, err + } + + return nodes, nil +} + +func nametomib(name string) (mib []_C_int, err error) { + // Split name into components. + var parts []string + last := 0 + for i := 0; i < len(name); i++ { + if name[i] == '.' { + parts = append(parts, name[last:i]) + last = i + 1 + } + } + parts = append(parts, name[last:]) + + // Discover the nodes and construct the MIB OID. 
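+ // For each dot-separated component, list the child nodes of the
+ // MIB prefix collected so far and append the number of the node
+ // whose name matches; a component with no match yields EINVAL.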
+ for partno, part := range parts { + nodes, err := sysctlNodes(mib) + if err != nil { + return nil, err + } + for _, node := range nodes { + n := make([]byte, 0) + for i := range node.Name { + if node.Name[i] != 0 { + n = append(n, byte(node.Name[i])) + } + } + if string(n) == part { + mib = append(mib, _C_int(node.Num)) + break + } + } + if len(mib) != partno+1 { + return nil, EINVAL + } + } + + return mib, nil +} + +func direntIno(buf []byte) (uint64, bool) { + return readInt(buf, unsafe.Offsetof(Dirent{}.Fileno), unsafe.Sizeof(Dirent{}.Fileno)) +} + +func direntReclen(buf []byte) (uint64, bool) { + return readInt(buf, unsafe.Offsetof(Dirent{}.Reclen), unsafe.Sizeof(Dirent{}.Reclen)) +} + +func direntNamlen(buf []byte) (uint64, bool) { + return readInt(buf, unsafe.Offsetof(Dirent{}.Namlen), unsafe.Sizeof(Dirent{}.Namlen)) +} + +func SysctlUvmexp(name string) (*Uvmexp, error) { + mib, err := sysctlmib(name) + if err != nil { + return nil, err + } + + n := uintptr(SizeofUvmexp) + var u Uvmexp + if err := sysctl(mib, (*byte)(unsafe.Pointer(&u)), &n, nil, 0); err != nil { + return nil, err + } + return &u, nil +} + +func Pipe(p []int) (err error) { + return Pipe2(p, 0) +} + +//sysnb pipe2(p *[2]_C_int, flags int) (err error) + +func Pipe2(p []int, flags int) error { + if len(p) != 2 { + return EINVAL + } + var pp [2]_C_int + err := pipe2(&pp, flags) + if err == nil { + p[0] = int(pp[0]) + p[1] = int(pp[1]) + } + return err +} + +//sys Getdents(fd int, buf []byte) (n int, err error) + +func Getdirentries(fd int, buf []byte, basep *uintptr) (n int, err error) { + n, err = Getdents(fd, buf) + if err != nil || basep == nil { + return + } + + var off int64 + off, err = Seek(fd, 0, 1 /* SEEK_CUR */) + if err != nil { + *basep = ^uintptr(0) + return + } + *basep = uintptr(off) + if unsafe.Sizeof(*basep) == 8 { + return + } + if off>>32 != 0 { + // We can't stuff the offset back into a uintptr, so any + // future calls would be suspect. Generate an error. + // EIO is allowed by getdirentries. + err = EIO + } + return +} + +//sys Getcwd(buf []byte) (n int, err error) = SYS___GETCWD + +// TODO +func sendfile(outfd int, infd int, offset *int64, count int) (written int, err error) { + return -1, ENOSYS +} + +//sys ioctl(fd int, req uint, arg uintptr) (err error) +//sys ioctlPtr(fd int, req uint, arg unsafe.Pointer) (err error) = SYS_IOCTL + +//sys sysctl(mib []_C_int, old *byte, oldlen *uintptr, new *byte, newlen uintptr) (err error) = SYS___SYSCTL + +func IoctlGetPtmget(fd int, req uint) (*Ptmget, error) { + var value Ptmget + err := ioctlPtr(fd, req, unsafe.Pointer(&value)) + return &value, err +} + +func Uname(uname *Utsname) error { + mib := []_C_int{CTL_KERN, KERN_OSTYPE} + n := unsafe.Sizeof(uname.Sysname) + if err := sysctl(mib, &uname.Sysname[0], &n, nil, 0); err != nil { + return err + } + + mib = []_C_int{CTL_KERN, KERN_HOSTNAME} + n = unsafe.Sizeof(uname.Nodename) + if err := sysctl(mib, &uname.Nodename[0], &n, nil, 0); err != nil { + return err + } + + mib = []_C_int{CTL_KERN, KERN_OSRELEASE} + n = unsafe.Sizeof(uname.Release) + if err := sysctl(mib, &uname.Release[0], &n, nil, 0); err != nil { + return err + } + + mib = []_C_int{CTL_KERN, KERN_VERSION} + n = unsafe.Sizeof(uname.Version) + if err := sysctl(mib, &uname.Version[0], &n, nil, 0); err != nil { + return err + } + + // The version might have newlines or tabs in it, convert them to + // spaces. 
+ for i, b := range uname.Version { + if b == '\n' || b == '\t' { + if i == len(uname.Version)-1 { + uname.Version[i] = 0 + } else { + uname.Version[i] = ' ' + } + } + } + + mib = []_C_int{CTL_HW, HW_MACHINE} + n = unsafe.Sizeof(uname.Machine) + if err := sysctl(mib, &uname.Machine[0], &n, nil, 0); err != nil { + return err + } + + return nil +} + +func Sendfile(outfd int, infd int, offset *int64, count int) (written int, err error) { + if raceenabled { + raceReleaseMerge(unsafe.Pointer(&ioSync)) + } + return sendfile(outfd, infd, offset, count) +} + +func Fstatvfs(fd int, buf *Statvfs_t) (err error) { + return Fstatvfs1(fd, buf, ST_WAIT) +} + +func Statvfs(path string, buf *Statvfs_t) (err error) { + return Statvfs1(path, buf, ST_WAIT) +} + +/* + * Exposed directly + */ +//sys Access(path string, mode uint32) (err error) +//sys Adjtime(delta *Timeval, olddelta *Timeval) (err error) +//sys Chdir(path string) (err error) +//sys Chflags(path string, flags int) (err error) +//sys Chmod(path string, mode uint32) (err error) +//sys Chown(path string, uid int, gid int) (err error) +//sys Chroot(path string) (err error) +//sys ClockGettime(clockid int32, time *Timespec) (err error) +//sys Close(fd int) (err error) +//sys Dup(fd int) (nfd int, err error) +//sys Dup2(from int, to int) (err error) +//sys Dup3(from int, to int, flags int) (err error) +//sys Exit(code int) +//sys ExtattrGetFd(fd int, attrnamespace int, attrname string, data uintptr, nbytes int) (ret int, err error) +//sys ExtattrSetFd(fd int, attrnamespace int, attrname string, data uintptr, nbytes int) (ret int, err error) +//sys ExtattrDeleteFd(fd int, attrnamespace int, attrname string) (err error) +//sys ExtattrListFd(fd int, attrnamespace int, data uintptr, nbytes int) (ret int, err error) +//sys ExtattrGetFile(file string, attrnamespace int, attrname string, data uintptr, nbytes int) (ret int, err error) +//sys ExtattrSetFile(file string, attrnamespace int, attrname string, data uintptr, nbytes int) (ret int, err error) +//sys ExtattrDeleteFile(file string, attrnamespace int, attrname string) (err error) +//sys ExtattrListFile(file string, attrnamespace int, data uintptr, nbytes int) (ret int, err error) +//sys ExtattrGetLink(link string, attrnamespace int, attrname string, data uintptr, nbytes int) (ret int, err error) +//sys ExtattrSetLink(link string, attrnamespace int, attrname string, data uintptr, nbytes int) (ret int, err error) +//sys ExtattrDeleteLink(link string, attrnamespace int, attrname string) (err error) +//sys ExtattrListLink(link string, attrnamespace int, data uintptr, nbytes int) (ret int, err error) +//sys Faccessat(dirfd int, path string, mode uint32, flags int) (err error) +//sys Fadvise(fd int, offset int64, length int64, advice int) (err error) = SYS_POSIX_FADVISE +//sys Fchdir(fd int) (err error) +//sys Fchflags(fd int, flags int) (err error) +//sys Fchmod(fd int, mode uint32) (err error) +//sys Fchmodat(dirfd int, path string, mode uint32, flags int) (err error) +//sys Fchown(fd int, uid int, gid int) (err error) +//sys Fchownat(dirfd int, path string, uid int, gid int, flags int) (err error) +//sys Flock(fd int, how int) (err error) +//sys Fpathconf(fd int, name int) (val int, err error) +//sys Fstat(fd int, stat *Stat_t) (err error) +//sys Fstatat(fd int, path string, stat *Stat_t, flags int) (err error) +//sys Fstatvfs1(fd int, buf *Statvfs_t, flags int) (err error) = SYS_FSTATVFS1 +//sys Fsync(fd int) (err error) +//sys Ftruncate(fd int, length int64) (err error) +//sysnb Getegid() (egid int) 
+//sysnb Geteuid() (uid int) +//sysnb Getgid() (gid int) +//sysnb Getpgid(pid int) (pgid int, err error) +//sysnb Getpgrp() (pgrp int) +//sysnb Getpid() (pid int) +//sysnb Getppid() (ppid int) +//sys Getpriority(which int, who int) (prio int, err error) +//sysnb Getrlimit(which int, lim *Rlimit) (err error) +//sysnb Getrusage(who int, rusage *Rusage) (err error) +//sysnb Getsid(pid int) (sid int, err error) +//sysnb Gettimeofday(tv *Timeval) (err error) +//sysnb Getuid() (uid int) +//sys Issetugid() (tainted bool) +//sys Kill(pid int, signum syscall.Signal) (err error) +//sys Kqueue() (fd int, err error) +//sys Lchown(path string, uid int, gid int) (err error) +//sys Link(path string, link string) (err error) +//sys Linkat(pathfd int, path string, linkfd int, link string, flags int) (err error) +//sys Listen(s int, backlog int) (err error) +//sys Lstat(path string, stat *Stat_t) (err error) +//sys Mkdir(path string, mode uint32) (err error) +//sys Mkdirat(dirfd int, path string, mode uint32) (err error) +//sys Mkfifo(path string, mode uint32) (err error) +//sys Mkfifoat(dirfd int, path string, mode uint32) (err error) +//sys Mknod(path string, mode uint32, dev int) (err error) +//sys Mknodat(dirfd int, path string, mode uint32, dev int) (err error) +//sys Nanosleep(time *Timespec, leftover *Timespec) (err error) +//sys Open(path string, mode int, perm uint32) (fd int, err error) +//sys Openat(dirfd int, path string, mode int, perm uint32) (fd int, err error) +//sys Pathconf(path string, name int) (val int, err error) +//sys pread(fd int, p []byte, offset int64) (n int, err error) +//sys pwrite(fd int, p []byte, offset int64) (n int, err error) +//sys read(fd int, p []byte) (n int, err error) +//sys Readlink(path string, buf []byte) (n int, err error) +//sys Readlinkat(dirfd int, path string, buf []byte) (n int, err error) +//sys Rename(from string, to string) (err error) +//sys Renameat(fromfd int, from string, tofd int, to string) (err error) +//sys Revoke(path string) (err error) +//sys Rmdir(path string) (err error) +//sys Seek(fd int, offset int64, whence int) (newoffset int64, err error) = SYS_LSEEK +//sys Select(nfd int, r *FdSet, w *FdSet, e *FdSet, timeout *Timeval) (n int, err error) +//sysnb Setegid(egid int) (err error) +//sysnb Seteuid(euid int) (err error) +//sysnb Setgid(gid int) (err error) +//sysnb Setpgid(pid int, pgid int) (err error) +//sys Setpriority(which int, who int, prio int) (err error) +//sysnb Setregid(rgid int, egid int) (err error) +//sysnb Setreuid(ruid int, euid int) (err error) +//sysnb Setsid() (pid int, err error) +//sysnb Settimeofday(tp *Timeval) (err error) +//sysnb Setuid(uid int) (err error) +//sys Stat(path string, stat *Stat_t) (err error) +//sys Statvfs1(path string, buf *Statvfs_t, flags int) (err error) = SYS_STATVFS1 +//sys Symlink(path string, link string) (err error) +//sys Symlinkat(oldpath string, newdirfd int, newpath string) (err error) +//sys Sync() (err error) +//sys Truncate(path string, length int64) (err error) +//sys Umask(newmask int) (oldmask int) +//sys Unlink(path string) (err error) +//sys Unlinkat(dirfd int, path string, flags int) (err error) +//sys Unmount(path string, flags int) (err error) +//sys write(fd int, p []byte) (n int, err error) +//sys mmap(addr uintptr, length uintptr, prot int, flag int, fd int, pos int64) (ret uintptr, err error) +//sys munmap(addr uintptr, length uintptr) (err error) +//sys utimensat(dirfd int, path string, times *[2]Timespec, flags int) (err error) + +const ( + mremapFixed = MAP_FIXED + 
mremapDontunmap = 0 + mremapMaymove = 0 +) + +//sys mremapNetBSD(oldp uintptr, oldsize uintptr, newp uintptr, newsize uintptr, flags int) (xaddr uintptr, err error) = SYS_MREMAP + +func mremap(oldaddr uintptr, oldlength uintptr, newlength uintptr, flags int, newaddr uintptr) (uintptr, error) { + return mremapNetBSD(oldaddr, oldlength, newaddr, newlength, flags) +} diff --git a/vendor/golang.org/x/sys/unix/syscall_solaris.go b/vendor/golang.org/x/sys/unix/syscall_solaris.go new file mode 100644 index 0000000000..b99cfa1342 --- /dev/null +++ b/vendor/golang.org/x/sys/unix/syscall_solaris.go @@ -0,0 +1,1103 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Solaris system calls. +// This file is compiled as ordinary Go code, +// but it is also input to mksyscall, +// which parses the //sys lines and generates system call stubs. +// Note that sometimes we use a lowercase //sys name and wrap +// it in our own nicer implementation, either here or in +// syscall_solaris.go or syscall_unix.go. + +package unix + +import ( + "fmt" + "os" + "runtime" + "sync" + "syscall" + "unsafe" +) + +// Implemented in runtime/syscall_solaris.go. +type syscallFunc uintptr + +func rawSysvicall6(trap, nargs, a1, a2, a3, a4, a5, a6 uintptr) (r1, r2 uintptr, err syscall.Errno) +func sysvicall6(trap, nargs, a1, a2, a3, a4, a5, a6 uintptr) (r1, r2 uintptr, err syscall.Errno) + +// SockaddrDatalink implements the Sockaddr interface for AF_LINK type sockets. +type SockaddrDatalink struct { + Family uint16 + Index uint16 + Type uint8 + Nlen uint8 + Alen uint8 + Slen uint8 + Data [244]int8 + raw RawSockaddrDatalink +} + +func direntIno(buf []byte) (uint64, bool) { + return readInt(buf, unsafe.Offsetof(Dirent{}.Ino), unsafe.Sizeof(Dirent{}.Ino)) +} + +func direntReclen(buf []byte) (uint64, bool) { + return readInt(buf, unsafe.Offsetof(Dirent{}.Reclen), unsafe.Sizeof(Dirent{}.Reclen)) +} + +func direntNamlen(buf []byte) (uint64, bool) { + reclen, ok := direntReclen(buf) + if !ok { + return 0, false + } + return reclen - uint64(unsafe.Offsetof(Dirent{}.Name)), true +} + +//sysnb pipe(p *[2]_C_int) (n int, err error) + +func Pipe(p []int) (err error) { + if len(p) != 2 { + return EINVAL + } + var pp [2]_C_int + n, err := pipe(&pp) + if n != 0 { + return err + } + if err == nil { + p[0] = int(pp[0]) + p[1] = int(pp[1]) + } + return nil +} + +//sysnb pipe2(p *[2]_C_int, flags int) (err error) + +func Pipe2(p []int, flags int) error { + if len(p) != 2 { + return EINVAL + } + var pp [2]_C_int + err := pipe2(&pp, flags) + if err == nil { + p[0] = int(pp[0]) + p[1] = int(pp[1]) + } + return err +} + +func (sa *SockaddrInet4) sockaddr() (unsafe.Pointer, _Socklen, error) { + if sa.Port < 0 || sa.Port > 0xFFFF { + return nil, 0, EINVAL + } + sa.raw.Family = AF_INET + p := (*[2]byte)(unsafe.Pointer(&sa.raw.Port)) + p[0] = byte(sa.Port >> 8) + p[1] = byte(sa.Port) + sa.raw.Addr = sa.Addr + return unsafe.Pointer(&sa.raw), SizeofSockaddrInet4, nil +} + +func (sa *SockaddrInet6) sockaddr() (unsafe.Pointer, _Socklen, error) { + if sa.Port < 0 || sa.Port > 0xFFFF { + return nil, 0, EINVAL + } + sa.raw.Family = AF_INET6 + p := (*[2]byte)(unsafe.Pointer(&sa.raw.Port)) + p[0] = byte(sa.Port >> 8) + p[1] = byte(sa.Port) + sa.raw.Scope_id = sa.ZoneId + sa.raw.Addr = sa.Addr + return unsafe.Pointer(&sa.raw), SizeofSockaddrInet6, nil +} + +func (sa *SockaddrUnix) sockaddr() (unsafe.Pointer, _Socklen, error) { + name := sa.Name + n 
:= len(name) + if n >= len(sa.raw.Path) { + return nil, 0, EINVAL + } + sa.raw.Family = AF_UNIX + for i := 0; i < n; i++ { + sa.raw.Path[i] = int8(name[i]) + } + // length is family (uint16), name, NUL. + sl := _Socklen(2) + if n > 0 { + sl += _Socklen(n) + 1 + } + if sa.raw.Path[0] == '@' { + sa.raw.Path[0] = 0 + // Don't count trailing NUL for abstract address. + sl-- + } + + return unsafe.Pointer(&sa.raw), sl, nil +} + +//sys getsockname(fd int, rsa *RawSockaddrAny, addrlen *_Socklen) (err error) = libsocket.getsockname + +func Getsockname(fd int) (sa Sockaddr, err error) { + var rsa RawSockaddrAny + var len _Socklen = SizeofSockaddrAny + if err = getsockname(fd, &rsa, &len); err != nil { + return + } + return anyToSockaddr(fd, &rsa) +} + +// GetsockoptString returns the string value of the socket option opt for the +// socket associated with fd at the given socket level. +func GetsockoptString(fd, level, opt int) (string, error) { + buf := make([]byte, 256) + vallen := _Socklen(len(buf)) + err := getsockopt(fd, level, opt, unsafe.Pointer(&buf[0]), &vallen) + if err != nil { + return "", err + } + return string(buf[:vallen-1]), nil +} + +const ImplementsGetwd = true + +//sys Getcwd(buf []byte) (n int, err error) + +func Getwd() (wd string, err error) { + var buf [PathMax]byte + // Getcwd will return an error if it failed for any reason. + _, err = Getcwd(buf[0:]) + if err != nil { + return "", err + } + n := clen(buf[:]) + if n < 1 { + return "", EINVAL + } + return string(buf[:n]), nil +} + +/* + * Wrapped + */ + +//sysnb getgroups(ngid int, gid *_Gid_t) (n int, err error) +//sysnb setgroups(ngid int, gid *_Gid_t) (err error) + +func Getgroups() (gids []int, err error) { + n, err := getgroups(0, nil) + // Check for error and sanity check group count. Newer versions of + // Solaris allow up to 1024 (NGROUPS_MAX). + if n < 0 || n > 1024 { + if err != nil { + return nil, err + } + return nil, EINVAL + } else if n == 0 { + return nil, nil + } + + a := make([]_Gid_t, n) + n, err = getgroups(n, &a[0]) + if n == -1 { + return nil, err + } + gids = make([]int, n) + for i, v := range a[0:n] { + gids[i] = int(v) + } + return +} + +func Setgroups(gids []int) (err error) { + if len(gids) == 0 { + return setgroups(0, nil) + } + + a := make([]_Gid_t, len(gids)) + for i, v := range gids { + a[i] = _Gid_t(v) + } + return setgroups(len(a), &a[0]) +} + +// ReadDirent reads directory entries from fd and writes them into buf. +func ReadDirent(fd int, buf []byte) (n int, err error) { + // Final argument is (basep *uintptr) and the syscall doesn't take nil. + // TODO(rsc): Can we use a single global basep for all calls? + return Getdents(fd, buf, new(uintptr)) +} + +// Wait status is 7 bits at bottom, either 0 (exited), +// 0x7F (stopped), or a signal number that caused an exit. +// The 0x80 bit is whether there was a core dump. +// An extra number (exit code, signal causing a stop) +// is in the high bits. 
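[Editor's note] To make the bit layout described above concrete, here is a small sketch (not part of the vendored file) that decodes a raw wait status by hand, mirroring the WaitStatus accessors that follow; the sample statuses are invented for illustration:

package main

import "fmt"

// decode applies the layout described above: the low 7 bits are 0 (exited),
// 0x7F (stopped), or the terminating signal; bit 0x80 flags a core dump;
// the exit code or stop signal sits in the next 8 bits.
func decode(w uint32) string {
	switch {
	case w&0x7f == 0:
		return fmt.Sprintf("exited with code %d", w>>8)
	case w&0x7f == 0x7f:
		return fmt.Sprintf("stopped by signal %d", (w>>8)&0xff)
	default:
		return fmt.Sprintf("killed by signal %d (core dump: %v)", w&0x7f, w&0x80 != 0)
	}
}

func main() {
	fmt.Println(decode(0x2a00)) // exited with code 42
	fmt.Println(decode(0x187f)) // stopped by signal 24
	fmt.Println(decode(0x0086)) // killed by signal 6 (core dump: true)
}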
+ +type WaitStatus uint32 + +const ( + mask = 0x7F + core = 0x80 + shift = 8 + + exited = 0 + stopped = 0x7F +) + +func (w WaitStatus) Exited() bool { return w&mask == exited } + +func (w WaitStatus) ExitStatus() int { + if w&mask != exited { + return -1 + } + return int(w >> shift) +} + +func (w WaitStatus) Signaled() bool { return w&mask != stopped && w&mask != 0 } + +func (w WaitStatus) Signal() syscall.Signal { + sig := syscall.Signal(w & mask) + if sig == stopped || sig == 0 { + return -1 + } + return sig +} + +func (w WaitStatus) CoreDump() bool { return w.Signaled() && w&core != 0 } + +func (w WaitStatus) Stopped() bool { return w&mask == stopped && syscall.Signal(w>>shift) != SIGSTOP } + +func (w WaitStatus) Continued() bool { return w&mask == stopped && syscall.Signal(w>>shift) == SIGSTOP } + +func (w WaitStatus) StopSignal() syscall.Signal { + if !w.Stopped() { + return -1 + } + return syscall.Signal(w>>shift) & 0xFF +} + +func (w WaitStatus) TrapCause() int { return -1 } + +//sys wait4(pid int32, statusp *_C_int, options int, rusage *Rusage) (wpid int32, err error) + +func Wait4(pid int, wstatus *WaitStatus, options int, rusage *Rusage) (int, error) { + var status _C_int + rpid, err := wait4(int32(pid), &status, options, rusage) + wpid := int(rpid) + if wpid == -1 { + return wpid, err + } + if wstatus != nil { + *wstatus = WaitStatus(status) + } + return wpid, nil +} + +//sys gethostname(buf []byte) (n int, err error) + +func Gethostname() (name string, err error) { + var buf [MaxHostNameLen]byte + n, err := gethostname(buf[:]) + if n != 0 { + return "", err + } + n = clen(buf[:]) + if n < 1 { + return "", EFAULT + } + return string(buf[:n]), nil +} + +//sys utimes(path string, times *[2]Timeval) (err error) + +func Utimes(path string, tv []Timeval) (err error) { + if tv == nil { + return utimes(path, nil) + } + if len(tv) != 2 { + return EINVAL + } + return utimes(path, (*[2]Timeval)(unsafe.Pointer(&tv[0]))) +} + +//sys utimensat(fd int, path string, times *[2]Timespec, flag int) (err error) + +func UtimesNano(path string, ts []Timespec) error { + if ts == nil { + return utimensat(AT_FDCWD, path, nil, 0) + } + if len(ts) != 2 { + return EINVAL + } + return utimensat(AT_FDCWD, path, (*[2]Timespec)(unsafe.Pointer(&ts[0])), 0) +} + +func UtimesNanoAt(dirfd int, path string, ts []Timespec, flags int) error { + if ts == nil { + return utimensat(dirfd, path, nil, flags) + } + if len(ts) != 2 { + return EINVAL + } + return utimensat(dirfd, path, (*[2]Timespec)(unsafe.Pointer(&ts[0])), flags) +} + +//sys fcntl(fd int, cmd int, arg int) (val int, err error) + +// FcntlInt performs a fcntl syscall on fd with the provided command and argument. +func FcntlInt(fd uintptr, cmd, arg int) (int, error) { + valptr, _, errno := sysvicall6(uintptr(unsafe.Pointer(&procfcntl)), 3, uintptr(fd), uintptr(cmd), uintptr(arg), 0, 0, 0) + var err error + if errno != 0 { + err = errno + } + return int(valptr), err +} + +// FcntlFlock performs a fcntl syscall for the F_GETLK, F_SETLK or F_SETLKW command. 
+func FcntlFlock(fd uintptr, cmd int, lk *Flock_t) error { + _, _, e1 := sysvicall6(uintptr(unsafe.Pointer(&procfcntl)), 3, uintptr(fd), uintptr(cmd), uintptr(unsafe.Pointer(lk)), 0, 0, 0) + if e1 != 0 { + return e1 + } + return nil +} + +//sys futimesat(fildes int, path *byte, times *[2]Timeval) (err error) + +func Futimesat(dirfd int, path string, tv []Timeval) error { + pathp, err := BytePtrFromString(path) + if err != nil { + return err + } + if tv == nil { + return futimesat(dirfd, pathp, nil) + } + if len(tv) != 2 { + return EINVAL + } + return futimesat(dirfd, pathp, (*[2]Timeval)(unsafe.Pointer(&tv[0]))) +} + +// Solaris doesn't have an futimes function because it allows NULL to be +// specified as the path for futimesat. However, Go doesn't like +// NULL-style string interfaces, so this simple wrapper is provided. +func Futimes(fd int, tv []Timeval) error { + if tv == nil { + return futimesat(fd, nil, nil) + } + if len(tv) != 2 { + return EINVAL + } + return futimesat(fd, nil, (*[2]Timeval)(unsafe.Pointer(&tv[0]))) +} + +func anyToSockaddr(fd int, rsa *RawSockaddrAny) (Sockaddr, error) { + switch rsa.Addr.Family { + case AF_UNIX: + pp := (*RawSockaddrUnix)(unsafe.Pointer(rsa)) + sa := new(SockaddrUnix) + // Assume path ends at NUL. + // This is not technically the Solaris semantics for + // abstract Unix domain sockets -- they are supposed + // to be uninterpreted fixed-size binary blobs -- but + // everyone uses this convention. + n := 0 + for n < len(pp.Path) && pp.Path[n] != 0 { + n++ + } + sa.Name = string(unsafe.Slice((*byte)(unsafe.Pointer(&pp.Path[0])), n)) + return sa, nil + + case AF_INET: + pp := (*RawSockaddrInet4)(unsafe.Pointer(rsa)) + sa := new(SockaddrInet4) + p := (*[2]byte)(unsafe.Pointer(&pp.Port)) + sa.Port = int(p[0])<<8 + int(p[1]) + sa.Addr = pp.Addr + return sa, nil + + case AF_INET6: + pp := (*RawSockaddrInet6)(unsafe.Pointer(rsa)) + sa := new(SockaddrInet6) + p := (*[2]byte)(unsafe.Pointer(&pp.Port)) + sa.Port = int(p[0])<<8 + int(p[1]) + sa.ZoneId = pp.Scope_id + sa.Addr = pp.Addr + return sa, nil + } + return nil, EAFNOSUPPORT +} + +//sys accept(s int, rsa *RawSockaddrAny, addrlen *_Socklen) (fd int, err error) = libsocket.accept + +func Accept(fd int) (nfd int, sa Sockaddr, err error) { + var rsa RawSockaddrAny + var len _Socklen = SizeofSockaddrAny + nfd, err = accept(fd, &rsa, &len) + if nfd == -1 { + return + } + sa, err = anyToSockaddr(fd, &rsa) + if err != nil { + Close(nfd) + nfd = 0 + } + return +} + +//sys recvmsg(s int, msg *Msghdr, flags int) (n int, err error) = libsocket.__xnet_recvmsg + +func recvmsgRaw(fd int, iov []Iovec, oob []byte, flags int, rsa *RawSockaddrAny) (n, oobn int, recvflags int, err error) { + var msg Msghdr + msg.Name = (*byte)(unsafe.Pointer(rsa)) + msg.Namelen = uint32(SizeofSockaddrAny) + var dummy byte + if len(oob) > 0 { + // receive at least one normal byte + if emptyIovecs(iov) { + var iova [1]Iovec + iova[0].Base = &dummy + iova[0].SetLen(1) + iov = iova[:] + } + msg.Accrightslen = int32(len(oob)) + } + if len(iov) > 0 { + msg.Iov = &iov[0] + msg.SetIovlen(len(iov)) + } + if n, err = recvmsg(fd, &msg, flags); n == -1 { + return + } + oobn = int(msg.Accrightslen) + return +} + +//sys sendmsg(s int, msg *Msghdr, flags int) (n int, err error) = libsocket.__xnet_sendmsg + +func sendmsgN(fd int, iov []Iovec, oob []byte, ptr unsafe.Pointer, salen _Socklen, flags int) (n int, err error) { + var msg Msghdr + msg.Name = (*byte)(unsafe.Pointer(ptr)) + msg.Namelen = uint32(salen) + var dummy byte + var empty bool + if 
len(oob) > 0 { + // send at least one normal byte + empty = emptyIovecs(iov) + if empty { + var iova [1]Iovec + iova[0].Base = &dummy + iova[0].SetLen(1) + iov = iova[:] + } + msg.Accrightslen = int32(len(oob)) + } + if len(iov) > 0 { + msg.Iov = &iov[0] + msg.SetIovlen(len(iov)) + } + if n, err = sendmsg(fd, &msg, flags); err != nil { + return 0, err + } + if len(oob) > 0 && empty { + n = 0 + } + return n, nil +} + +//sys acct(path *byte) (err error) + +func Acct(path string) (err error) { + if len(path) == 0 { + // Assume caller wants to disable accounting. + return acct(nil) + } + + pathp, err := BytePtrFromString(path) + if err != nil { + return err + } + return acct(pathp) +} + +//sys __makedev(version int, major uint, minor uint) (val uint64) + +func Mkdev(major, minor uint32) uint64 { + return __makedev(NEWDEV, uint(major), uint(minor)) +} + +//sys __major(version int, dev uint64) (val uint) + +func Major(dev uint64) uint32 { + return uint32(__major(NEWDEV, dev)) +} + +//sys __minor(version int, dev uint64) (val uint) + +func Minor(dev uint64) uint32 { + return uint32(__minor(NEWDEV, dev)) +} + +/* + * Expose the ioctl function + */ + +//sys ioctlRet(fd int, req int, arg uintptr) (ret int, err error) = libc.ioctl +//sys ioctlPtrRet(fd int, req int, arg unsafe.Pointer) (ret int, err error) = libc.ioctl + +func ioctl(fd int, req int, arg uintptr) (err error) { + _, err = ioctlRet(fd, req, arg) + return err +} + +func ioctlPtr(fd int, req int, arg unsafe.Pointer) (err error) { + _, err = ioctlPtrRet(fd, req, arg) + return err +} + +func IoctlSetTermio(fd int, req int, value *Termio) error { + return ioctlPtr(fd, req, unsafe.Pointer(value)) +} + +func IoctlGetTermio(fd int, req int) (*Termio, error) { + var value Termio + err := ioctlPtr(fd, req, unsafe.Pointer(&value)) + return &value, err +} + +//sys poll(fds *PollFd, nfds int, timeout int) (n int, err error) + +func Poll(fds []PollFd, timeout int) (n int, err error) { + if len(fds) == 0 { + return poll(nil, 0, timeout) + } + return poll(&fds[0], len(fds), timeout) +} + +func Sendfile(outfd int, infd int, offset *int64, count int) (written int, err error) { + if raceenabled { + raceReleaseMerge(unsafe.Pointer(&ioSync)) + } + return sendfile(outfd, infd, offset, count) +} + +/* + * Exposed directly + */ +//sys Access(path string, mode uint32) (err error) +//sys Adjtime(delta *Timeval, olddelta *Timeval) (err error) +//sys Chdir(path string) (err error) +//sys Chmod(path string, mode uint32) (err error) +//sys Chown(path string, uid int, gid int) (err error) +//sys Chroot(path string) (err error) +//sys ClockGettime(clockid int32, time *Timespec) (err error) +//sys Close(fd int) (err error) +//sys Creat(path string, mode uint32) (fd int, err error) +//sys Dup(fd int) (nfd int, err error) +//sys Dup2(oldfd int, newfd int) (err error) +//sys Exit(code int) +//sys Faccessat(dirfd int, path string, mode uint32, flags int) (err error) +//sys Fchdir(fd int) (err error) +//sys Fchmod(fd int, mode uint32) (err error) +//sys Fchmodat(dirfd int, path string, mode uint32, flags int) (err error) +//sys Fchown(fd int, uid int, gid int) (err error) +//sys Fchownat(dirfd int, path string, uid int, gid int, flags int) (err error) +//sys Fdatasync(fd int) (err error) +//sys Flock(fd int, how int) (err error) +//sys Fpathconf(fd int, name int) (val int, err error) +//sys Fstat(fd int, stat *Stat_t) (err error) +//sys Fstatat(fd int, path string, stat *Stat_t, flags int) (err error) +//sys Fstatvfs(fd int, vfsstat *Statvfs_t) (err error) +//sys 
Getdents(fd int, buf []byte, basep *uintptr) (n int, err error) +//sysnb Getgid() (gid int) +//sysnb Getpid() (pid int) +//sysnb Getpgid(pid int) (pgid int, err error) +//sysnb Getpgrp() (pgid int, err error) +//sys Geteuid() (euid int) +//sys Getegid() (egid int) +//sys Getppid() (ppid int) +//sys Getpriority(which int, who int) (n int, err error) +//sysnb Getrlimit(which int, lim *Rlimit) (err error) +//sysnb Getrusage(who int, rusage *Rusage) (err error) +//sysnb Getsid(pid int) (sid int, err error) +//sysnb Gettimeofday(tv *Timeval) (err error) +//sysnb Getuid() (uid int) +//sys Kill(pid int, signum syscall.Signal) (err error) +//sys Lchown(path string, uid int, gid int) (err error) +//sys Link(path string, link string) (err error) +//sys Listen(s int, backlog int) (err error) = libsocket.__xnet_llisten +//sys Lstat(path string, stat *Stat_t) (err error) +//sys Madvise(b []byte, advice int) (err error) +//sys Mkdir(path string, mode uint32) (err error) +//sys Mkdirat(dirfd int, path string, mode uint32) (err error) +//sys Mkfifo(path string, mode uint32) (err error) +//sys Mkfifoat(dirfd int, path string, mode uint32) (err error) +//sys Mknod(path string, mode uint32, dev int) (err error) +//sys Mknodat(dirfd int, path string, mode uint32, dev int) (err error) +//sys Mlock(b []byte) (err error) +//sys Mlockall(flags int) (err error) +//sys Mprotect(b []byte, prot int) (err error) +//sys Msync(b []byte, flags int) (err error) +//sys Munlock(b []byte) (err error) +//sys Munlockall() (err error) +//sys Nanosleep(time *Timespec, leftover *Timespec) (err error) +//sys Open(path string, mode int, perm uint32) (fd int, err error) +//sys Openat(dirfd int, path string, flags int, mode uint32) (fd int, err error) +//sys Pathconf(path string, name int) (val int, err error) +//sys Pause() (err error) +//sys pread(fd int, p []byte, offset int64) (n int, err error) +//sys pwrite(fd int, p []byte, offset int64) (n int, err error) +//sys read(fd int, p []byte) (n int, err error) +//sys Readlink(path string, buf []byte) (n int, err error) +//sys Rename(from string, to string) (err error) +//sys Renameat(olddirfd int, oldpath string, newdirfd int, newpath string) (err error) +//sys Rmdir(path string) (err error) +//sys Seek(fd int, offset int64, whence int) (newoffset int64, err error) = lseek +//sys Select(nfd int, r *FdSet, w *FdSet, e *FdSet, timeout *Timeval) (n int, err error) +//sysnb Setegid(egid int) (err error) +//sysnb Seteuid(euid int) (err error) +//sysnb Setgid(gid int) (err error) +//sys Sethostname(p []byte) (err error) +//sysnb Setpgid(pid int, pgid int) (err error) +//sys Setpriority(which int, who int, prio int) (err error) +//sysnb Setregid(rgid int, egid int) (err error) +//sysnb Setreuid(ruid int, euid int) (err error) +//sysnb Setsid() (pid int, err error) +//sysnb Setuid(uid int) (err error) +//sys Shutdown(s int, how int) (err error) = libsocket.shutdown +//sys Stat(path string, stat *Stat_t) (err error) +//sys Statvfs(path string, vfsstat *Statvfs_t) (err error) +//sys Symlink(path string, link string) (err error) +//sys Sync() (err error) +//sys Sysconf(which int) (n int64, err error) +//sysnb Times(tms *Tms) (ticks uintptr, err error) +//sys Truncate(path string, length int64) (err error) +//sys Fsync(fd int) (err error) +//sys Ftruncate(fd int, length int64) (err error) +//sys Umask(mask int) (oldmask int) +//sysnb Uname(buf *Utsname) (err error) +//sys Unmount(target string, flags int) (err error) = libc.umount +//sys Unlink(path string) (err error) +//sys Unlinkat(dirfd 
int, path string, flags int) (err error) +//sys Ustat(dev int, ubuf *Ustat_t) (err error) +//sys Utime(path string, buf *Utimbuf) (err error) +//sys bind(s int, addr unsafe.Pointer, addrlen _Socklen) (err error) = libsocket.__xnet_bind +//sys connect(s int, addr unsafe.Pointer, addrlen _Socklen) (err error) = libsocket.__xnet_connect +//sys mmap(addr uintptr, length uintptr, prot int, flag int, fd int, pos int64) (ret uintptr, err error) +//sys munmap(addr uintptr, length uintptr) (err error) +//sys sendfile(outfd int, infd int, offset *int64, count int) (written int, err error) = libsendfile.sendfile +//sys sendto(s int, buf []byte, flags int, to unsafe.Pointer, addrlen _Socklen) (err error) = libsocket.__xnet_sendto +//sys socket(domain int, typ int, proto int) (fd int, err error) = libsocket.__xnet_socket +//sysnb socketpair(domain int, typ int, proto int, fd *[2]int32) (err error) = libsocket.__xnet_socketpair +//sys write(fd int, p []byte) (n int, err error) +//sys getsockopt(s int, level int, name int, val unsafe.Pointer, vallen *_Socklen) (err error) = libsocket.__xnet_getsockopt +//sysnb getpeername(fd int, rsa *RawSockaddrAny, addrlen *_Socklen) (err error) = libsocket.getpeername +//sys setsockopt(s int, level int, name int, val unsafe.Pointer, vallen uintptr) (err error) = libsocket.setsockopt +//sys recvfrom(fd int, p []byte, flags int, from *RawSockaddrAny, fromlen *_Socklen) (n int, err error) = libsocket.recvfrom + +// Event Ports + +type fileObjCookie struct { + fobj *fileObj + cookie interface{} +} + +// EventPort provides a safe abstraction on top of Solaris/illumos Event Ports. +type EventPort struct { + port int + mu sync.Mutex + fds map[uintptr]*fileObjCookie + paths map[string]*fileObjCookie + // The user cookie presents an interesting challenge from a memory management perspective. + // There are two paths by which we can discover that it is no longer in use: + // 1. The user calls port_dissociate before any events fire + // 2. An event fires and we return it to the user + // The tricky situation is if the event has fired in the kernel but + // the user hasn't requested/received it yet. + // If the user wants to port_dissociate before the event has been processed, + // we should handle things gracefully. To do so, we need to keep an extra + // reference to the cookie around until the event is processed + // thus the otherwise seemingly extraneous "cookies" map + // The key of this map is a pointer to the corresponding fCookie + cookies map[*fileObjCookie]struct{} +} + +// PortEvent is an abstraction of the port_event C struct. +// Compare Source against PORT_SOURCE_FILE or PORT_SOURCE_FD +// to see if Path or Fd was the event source. The other will be +// uninitialized. +type PortEvent struct { + Cookie interface{} + Events int32 + Fd uintptr + Path string + Source uint16 + fobj *fileObj +} + +// NewEventPort creates a new EventPort including the +// underlying call to port_create(3c). 
+func NewEventPort() (*EventPort, error) { + port, err := port_create() + if err != nil { + return nil, err + } + e := &EventPort{ + port: port, + fds: make(map[uintptr]*fileObjCookie), + paths: make(map[string]*fileObjCookie), + cookies: make(map[*fileObjCookie]struct{}), + } + return e, nil +} + +//sys port_create() (n int, err error) +//sys port_associate(port int, source int, object uintptr, events int, user *byte) (n int, err error) +//sys port_dissociate(port int, source int, object uintptr) (n int, err error) +//sys port_get(port int, pe *portEvent, timeout *Timespec) (n int, err error) +//sys port_getn(port int, pe *portEvent, max uint32, nget *uint32, timeout *Timespec) (n int, err error) + +// Close closes the event port. +func (e *EventPort) Close() error { + e.mu.Lock() + defer e.mu.Unlock() + err := Close(e.port) + if err != nil { + return err + } + e.fds = nil + e.paths = nil + e.cookies = nil + return nil +} + +// PathIsWatched checks to see if path is associated with this EventPort. +func (e *EventPort) PathIsWatched(path string) bool { + e.mu.Lock() + defer e.mu.Unlock() + _, found := e.paths[path] + return found +} + +// FdIsWatched checks to see if fd is associated with this EventPort. +func (e *EventPort) FdIsWatched(fd uintptr) bool { + e.mu.Lock() + defer e.mu.Unlock() + _, found := e.fds[fd] + return found +} + +// AssociatePath wraps port_associate(3c) for a filesystem path including +// creating the necessary file_obj from the provided stat information. +func (e *EventPort) AssociatePath(path string, stat os.FileInfo, events int, cookie interface{}) error { + e.mu.Lock() + defer e.mu.Unlock() + if _, found := e.paths[path]; found { + return fmt.Errorf("%v is already associated with this Event Port", path) + } + fCookie, err := createFileObjCookie(path, stat, cookie) + if err != nil { + return err + } + _, err = port_associate(e.port, PORT_SOURCE_FILE, uintptr(unsafe.Pointer(fCookie.fobj)), events, (*byte)(unsafe.Pointer(fCookie))) + if err != nil { + return err + } + e.paths[path] = fCookie + e.cookies[fCookie] = struct{}{} + return nil +} + +// DissociatePath wraps port_dissociate(3c) for a filesystem path. +func (e *EventPort) DissociatePath(path string) error { + e.mu.Lock() + defer e.mu.Unlock() + f, ok := e.paths[path] + if !ok { + return fmt.Errorf("%v is not associated with this Event Port", path) + } + _, err := port_dissociate(e.port, PORT_SOURCE_FILE, uintptr(unsafe.Pointer(f.fobj))) + // If the path is no longer associated with this event port (ENOENT) + // we should delete it from our map. We can still return ENOENT to the caller. + // But we need to save the cookie + if err != nil && err != ENOENT { + return err + } + if err == nil { + // dissociate was successful, safe to delete the cookie + fCookie := e.paths[path] + delete(e.cookies, fCookie) + } + delete(e.paths, path) + return err +} + +// AssociateFd wraps calls to port_associate(3c) on file descriptors. +func (e *EventPort) AssociateFd(fd uintptr, events int, cookie interface{}) error { + e.mu.Lock() + defer e.mu.Unlock() + if _, found := e.fds[fd]; found { + return fmt.Errorf("%v is already associated with this Event Port", fd) + } + fCookie, err := createFileObjCookie("", nil, cookie) + if err != nil { + return err + } + _, err = port_associate(e.port, PORT_SOURCE_FD, fd, events, (*byte)(unsafe.Pointer(fCookie))) + if err != nil { + return err + } + e.fds[fd] = fCookie + e.cookies[fCookie] = struct{}{} + return nil +} + +// DissociateFd wraps calls to port_dissociate(3c) on file descriptors. 
+func (e *EventPort) DissociateFd(fd uintptr) error { + e.mu.Lock() + defer e.mu.Unlock() + _, ok := e.fds[fd] + if !ok { + return fmt.Errorf("%v is not associated with this Event Port", fd) + } + _, err := port_dissociate(e.port, PORT_SOURCE_FD, fd) + if err != nil && err != ENOENT { + return err + } + if err == nil { + // dissociate was successful, safe to delete the cookie + fCookie := e.fds[fd] + delete(e.cookies, fCookie) + } + delete(e.fds, fd) + return err +} + +func createFileObjCookie(name string, stat os.FileInfo, cookie interface{}) (*fileObjCookie, error) { + fCookie := new(fileObjCookie) + fCookie.cookie = cookie + if name != "" && stat != nil { + fCookie.fobj = new(fileObj) + bs, err := ByteSliceFromString(name) + if err != nil { + return nil, err + } + fCookie.fobj.Name = (*int8)(unsafe.Pointer(&bs[0])) + s := stat.Sys().(*syscall.Stat_t) + fCookie.fobj.Atim.Sec = s.Atim.Sec + fCookie.fobj.Atim.Nsec = s.Atim.Nsec + fCookie.fobj.Mtim.Sec = s.Mtim.Sec + fCookie.fobj.Mtim.Nsec = s.Mtim.Nsec + fCookie.fobj.Ctim.Sec = s.Ctim.Sec + fCookie.fobj.Ctim.Nsec = s.Ctim.Nsec + } + return fCookie, nil +} + +// GetOne wraps port_get(3c) and returns a single PortEvent. +func (e *EventPort) GetOne(t *Timespec) (*PortEvent, error) { + pe := new(portEvent) + _, err := port_get(e.port, pe, t) + if err != nil { + return nil, err + } + p := new(PortEvent) + e.mu.Lock() + defer e.mu.Unlock() + err = e.peIntToExt(pe, p) + if err != nil { + return nil, err + } + return p, nil +} + +// peIntToExt converts a cgo portEvent struct into the friendlier PortEvent +// NOTE: Always call this function while holding the e.mu mutex +func (e *EventPort) peIntToExt(peInt *portEvent, peExt *PortEvent) error { + if e.cookies == nil { + return fmt.Errorf("this EventPort is already closed") + } + peExt.Events = peInt.Events + peExt.Source = peInt.Source + fCookie := (*fileObjCookie)(unsafe.Pointer(peInt.User)) + _, found := e.cookies[fCookie] + + if !found { + panic("unexpected event port address; may be due to kernel bug; see https://go.dev/issue/54254") + } + peExt.Cookie = fCookie.cookie + delete(e.cookies, fCookie) + + switch peInt.Source { + case PORT_SOURCE_FD: + peExt.Fd = uintptr(peInt.Object) + // Only remove the fds entry if it exists and this cookie matches + if fobj, ok := e.fds[peExt.Fd]; ok { + if fobj == fCookie { + delete(e.fds, peExt.Fd) + } + } + case PORT_SOURCE_FILE: + peExt.fobj = fCookie.fobj + peExt.Path = BytePtrToString((*byte)(unsafe.Pointer(peExt.fobj.Name))) + // Only remove the paths entry if it exists and this cookie matches + if fobj, ok := e.paths[peExt.Path]; ok { + if fobj == fCookie { + delete(e.paths, peExt.Path) + } + } + } + return nil +} + +// Pending wraps port_getn(3c) and returns how many events are pending. +func (e *EventPort) Pending() (int, error) { + var n uint32 = 0 + _, err := port_getn(e.port, nil, 0, &n, nil) + return int(n), err +} + +// Get wraps port_getn(3c) and fills a slice of PortEvent. +// It will block until either min events have been received +// or the timeout has been exceeded. It will return how many +// events were actually received along with any error information. 
+func (e *EventPort) Get(s []PortEvent, min int, timeout *Timespec) (int, error) { + if min == 0 { + return 0, fmt.Errorf("need to request at least one event or use Pending() instead") + } + if len(s) < min { + return 0, fmt.Errorf("len(s) (%d) is less than min events requested (%d)", len(s), min) + } + got := uint32(min) + max := uint32(len(s)) + var err error + ps := make([]portEvent, max) + _, err = port_getn(e.port, &ps[0], max, &got, timeout) + // got will be trustworthy with ETIME, but not any other error. + if err != nil && err != ETIME { + return 0, err + } + e.mu.Lock() + defer e.mu.Unlock() + valid := 0 + for i := 0; i < int(got); i++ { + err2 := e.peIntToExt(&ps[i], &s[i]) + if err2 != nil { + if valid == 0 && err == nil { + // If err2 is the only error and there are no valid events + // to return, return it to the caller. + err = err2 + } + break + } + valid = i + 1 + } + return valid, err +} + +//sys putmsg(fd int, clptr *strbuf, dataptr *strbuf, flags int) (err error) + +func Putmsg(fd int, cl []byte, data []byte, flags int) (err error) { + var clp, datap *strbuf + if len(cl) > 0 { + clp = &strbuf{ + Len: int32(len(cl)), + Buf: (*int8)(unsafe.Pointer(&cl[0])), + } + } + if len(data) > 0 { + datap = &strbuf{ + Len: int32(len(data)), + Buf: (*int8)(unsafe.Pointer(&data[0])), + } + } + return putmsg(fd, clp, datap, flags) +} + +//sys getmsg(fd int, clptr *strbuf, dataptr *strbuf, flags *int) (err error) + +func Getmsg(fd int, cl []byte, data []byte) (retCl []byte, retData []byte, flags int, err error) { + var clp, datap *strbuf + if len(cl) > 0 { + clp = &strbuf{ + Maxlen: int32(len(cl)), + Buf: (*int8)(unsafe.Pointer(&cl[0])), + } + } + if len(data) > 0 { + datap = &strbuf{ + Maxlen: int32(len(data)), + Buf: (*int8)(unsafe.Pointer(&data[0])), + } + } + + if err = getmsg(fd, clp, datap, &flags); err != nil { + return nil, nil, 0, err + } + + if len(cl) > 0 { + retCl = cl[:clp.Len] + } + if len(data) > 0 { + retData = data[:datap.Len] + } + return retCl, retData, flags, nil +} + +func IoctlSetIntRetInt(fd int, req int, arg int) (int, error) { + return ioctlRet(fd, req, uintptr(arg)) +} + +func IoctlSetString(fd int, req int, val string) error { + bs := make([]byte, len(val)+1) + copy(bs[:len(bs)-1], val) + err := ioctlPtr(fd, req, unsafe.Pointer(&bs[0])) + runtime.KeepAlive(&bs[0]) + return err +} + +// Lifreq Helpers + +func (l *Lifreq) SetName(name string) error { + if len(name) >= len(l.Name) { + return fmt.Errorf("name cannot be more than %d characters", len(l.Name)-1) + } + for i := range name { + l.Name[i] = int8(name[i]) + } + return nil +} + +func (l *Lifreq) SetLifruInt(d int) { + *(*int)(unsafe.Pointer(&l.Lifru[0])) = d +} + +func (l *Lifreq) GetLifruInt() int { + return *(*int)(unsafe.Pointer(&l.Lifru[0])) +} + +func (l *Lifreq) SetLifruUint(d uint) { + *(*uint)(unsafe.Pointer(&l.Lifru[0])) = d +} + +func (l *Lifreq) GetLifruUint() uint { + return *(*uint)(unsafe.Pointer(&l.Lifru[0])) +} + +func IoctlLifreq(fd int, req int, l *Lifreq) error { + return ioctlPtr(fd, req, unsafe.Pointer(l)) +} + +// Strioctl Helpers + +func (s *Strioctl) SetInt(i int) { + s.Len = int32(unsafe.Sizeof(i)) + s.Dp = (*int8)(unsafe.Pointer(&i)) +} + +func IoctlSetStrioctlRetInt(fd int, req int, s *Strioctl) (int, error) { + return ioctlPtrRet(fd, req, unsafe.Pointer(s)) +} diff --git a/vendor/golang.org/x/sys/unix/ztypes_linux_riscv64.go b/vendor/golang.org/x/sys/unix/ztypes_linux_riscv64.go new file mode 100644 index 0000000000..1b4c97c32a --- /dev/null +++ 
b/vendor/golang.org/x/sys/unix/ztypes_linux_riscv64.go @@ -0,0 +1,747 @@ +// cgo -godefs -objdir=/tmp/riscv64/cgo -- -Wall -Werror -static -I/tmp/riscv64/include linux/types.go | go run mkpost.go +// Code generated by the command above; see README.md. DO NOT EDIT. + +//go:build riscv64 && linux +// +build riscv64,linux + +package unix + +const ( + SizeofPtr = 0x8 + SizeofLong = 0x8 +) + +type ( + _C_long int64 +) + +type Timespec struct { + Sec int64 + Nsec int64 +} + +type Timeval struct { + Sec int64 + Usec int64 +} + +type Timex struct { + Modes uint32 + Offset int64 + Freq int64 + Maxerror int64 + Esterror int64 + Status int32 + Constant int64 + Precision int64 + Tolerance int64 + Time Timeval + Tick int64 + Ppsfreq int64 + Jitter int64 + Shift int32 + Stabil int64 + Jitcnt int64 + Calcnt int64 + Errcnt int64 + Stbcnt int64 + Tai int32 + _ [44]byte +} + +type Time_t int64 + +type Tms struct { + Utime int64 + Stime int64 + Cutime int64 + Cstime int64 +} + +type Utimbuf struct { + Actime int64 + Modtime int64 +} + +type Rusage struct { + Utime Timeval + Stime Timeval + Maxrss int64 + Ixrss int64 + Idrss int64 + Isrss int64 + Minflt int64 + Majflt int64 + Nswap int64 + Inblock int64 + Oublock int64 + Msgsnd int64 + Msgrcv int64 + Nsignals int64 + Nvcsw int64 + Nivcsw int64 +} + +type Stat_t struct { + Dev uint64 + Ino uint64 + Mode uint32 + Nlink uint32 + Uid uint32 + Gid uint32 + Rdev uint64 + _ uint64 + Size int64 + Blksize int32 + _ int32 + Blocks int64 + Atim Timespec + Mtim Timespec + Ctim Timespec + _ [2]int32 +} + +type Dirent struct { + Ino uint64 + Off int64 + Reclen uint16 + Type uint8 + Name [256]uint8 + _ [5]byte +} + +type Flock_t struct { + Type int16 + Whence int16 + Start int64 + Len int64 + Pid int32 + _ [4]byte +} + +type DmNameList struct { + Dev uint64 + Next uint32 + Name [0]byte + _ [4]byte +} + +const ( + FADV_DONTNEED = 0x4 + FADV_NOREUSE = 0x5 +) + +type RawSockaddrNFCLLCP struct { + Sa_family uint16 + Dev_idx uint32 + Target_idx uint32 + Nfc_protocol uint32 + Dsap uint8 + Ssap uint8 + Service_name [63]uint8 + Service_name_len uint64 +} + +type RawSockaddr struct { + Family uint16 + Data [14]uint8 +} + +type RawSockaddrAny struct { + Addr RawSockaddr + Pad [96]uint8 +} + +type Iovec struct { + Base *byte + Len uint64 +} + +type Msghdr struct { + Name *byte + Namelen uint32 + Iov *Iovec + Iovlen uint64 + Control *byte + Controllen uint64 + Flags int32 + _ [4]byte +} + +type Cmsghdr struct { + Len uint64 + Level int32 + Type int32 +} + +type ifreq struct { + Ifrn [16]byte + Ifru [24]byte +} + +const ( + SizeofSockaddrNFCLLCP = 0x60 + SizeofIovec = 0x10 + SizeofMsghdr = 0x38 + SizeofCmsghdr = 0x10 +) + +const ( + SizeofSockFprog = 0x10 +) + +type PtraceRegs struct { + Pc uint64 + Ra uint64 + Sp uint64 + Gp uint64 + Tp uint64 + T0 uint64 + T1 uint64 + T2 uint64 + S0 uint64 + S1 uint64 + A0 uint64 + A1 uint64 + A2 uint64 + A3 uint64 + A4 uint64 + A5 uint64 + A6 uint64 + A7 uint64 + S2 uint64 + S3 uint64 + S4 uint64 + S5 uint64 + S6 uint64 + S7 uint64 + S8 uint64 + S9 uint64 + S10 uint64 + S11 uint64 + T3 uint64 + T4 uint64 + T5 uint64 + T6 uint64 +} + +type FdSet struct { + Bits [16]int64 +} + +type Sysinfo_t struct { + Uptime int64 + Loads [3]uint64 + Totalram uint64 + Freeram uint64 + Sharedram uint64 + Bufferram uint64 + Totalswap uint64 + Freeswap uint64 + Procs uint16 + Pad uint16 + Totalhigh uint64 + Freehigh uint64 + Unit uint32 + _ [0]uint8 + _ [4]byte +} + +type Ustat_t struct { + Tfree int32 + Tinode uint64 + Fname [6]uint8 + Fpack [6]uint8 + _ [4]byte +} + 
+type EpollEvent struct { + Events uint32 + _ int32 + Fd int32 + Pad int32 +} + +const ( + OPEN_TREE_CLOEXEC = 0x80000 +) + +const ( + POLLRDHUP = 0x2000 +) + +type Sigset_t struct { + Val [16]uint64 +} + +const _C__NSIG = 0x41 + +const ( + SIG_BLOCK = 0x0 + SIG_UNBLOCK = 0x1 + SIG_SETMASK = 0x2 +) + +type Siginfo struct { + Signo int32 + Errno int32 + Code int32 + _ int32 + _ [112]byte +} + +type Termios struct { + Iflag uint32 + Oflag uint32 + Cflag uint32 + Lflag uint32 + Line uint8 + Cc [19]uint8 + Ispeed uint32 + Ospeed uint32 +} + +type Taskstats struct { + Version uint16 + Ac_exitcode uint32 + Ac_flag uint8 + Ac_nice uint8 + Cpu_count uint64 + Cpu_delay_total uint64 + Blkio_count uint64 + Blkio_delay_total uint64 + Swapin_count uint64 + Swapin_delay_total uint64 + Cpu_run_real_total uint64 + Cpu_run_virtual_total uint64 + Ac_comm [32]uint8 + Ac_sched uint8 + Ac_pad [3]uint8 + _ [4]byte + Ac_uid uint32 + Ac_gid uint32 + Ac_pid uint32 + Ac_ppid uint32 + Ac_btime uint32 + Ac_etime uint64 + Ac_utime uint64 + Ac_stime uint64 + Ac_minflt uint64 + Ac_majflt uint64 + Coremem uint64 + Virtmem uint64 + Hiwater_rss uint64 + Hiwater_vm uint64 + Read_char uint64 + Write_char uint64 + Read_syscalls uint64 + Write_syscalls uint64 + Read_bytes uint64 + Write_bytes uint64 + Cancelled_write_bytes uint64 + Nvcsw uint64 + Nivcsw uint64 + Ac_utimescaled uint64 + Ac_stimescaled uint64 + Cpu_scaled_run_real_total uint64 + Freepages_count uint64 + Freepages_delay_total uint64 + Thrashing_count uint64 + Thrashing_delay_total uint64 + Ac_btime64 uint64 + Compact_count uint64 + Compact_delay_total uint64 + Ac_tgid uint32 + Ac_tgetime uint64 + Ac_exe_dev uint64 + Ac_exe_inode uint64 + Wpcopy_count uint64 + Wpcopy_delay_total uint64 + Irq_count uint64 + Irq_delay_total uint64 +} + +type cpuMask uint64 + +const ( + _NCPUBITS = 0x40 +) + +const ( + CBitFieldMaskBit0 = 0x1 + CBitFieldMaskBit1 = 0x2 + CBitFieldMaskBit2 = 0x4 + CBitFieldMaskBit3 = 0x8 + CBitFieldMaskBit4 = 0x10 + CBitFieldMaskBit5 = 0x20 + CBitFieldMaskBit6 = 0x40 + CBitFieldMaskBit7 = 0x80 + CBitFieldMaskBit8 = 0x100 + CBitFieldMaskBit9 = 0x200 + CBitFieldMaskBit10 = 0x400 + CBitFieldMaskBit11 = 0x800 + CBitFieldMaskBit12 = 0x1000 + CBitFieldMaskBit13 = 0x2000 + CBitFieldMaskBit14 = 0x4000 + CBitFieldMaskBit15 = 0x8000 + CBitFieldMaskBit16 = 0x10000 + CBitFieldMaskBit17 = 0x20000 + CBitFieldMaskBit18 = 0x40000 + CBitFieldMaskBit19 = 0x80000 + CBitFieldMaskBit20 = 0x100000 + CBitFieldMaskBit21 = 0x200000 + CBitFieldMaskBit22 = 0x400000 + CBitFieldMaskBit23 = 0x800000 + CBitFieldMaskBit24 = 0x1000000 + CBitFieldMaskBit25 = 0x2000000 + CBitFieldMaskBit26 = 0x4000000 + CBitFieldMaskBit27 = 0x8000000 + CBitFieldMaskBit28 = 0x10000000 + CBitFieldMaskBit29 = 0x20000000 + CBitFieldMaskBit30 = 0x40000000 + CBitFieldMaskBit31 = 0x80000000 + CBitFieldMaskBit32 = 0x100000000 + CBitFieldMaskBit33 = 0x200000000 + CBitFieldMaskBit34 = 0x400000000 + CBitFieldMaskBit35 = 0x800000000 + CBitFieldMaskBit36 = 0x1000000000 + CBitFieldMaskBit37 = 0x2000000000 + CBitFieldMaskBit38 = 0x4000000000 + CBitFieldMaskBit39 = 0x8000000000 + CBitFieldMaskBit40 = 0x10000000000 + CBitFieldMaskBit41 = 0x20000000000 + CBitFieldMaskBit42 = 0x40000000000 + CBitFieldMaskBit43 = 0x80000000000 + CBitFieldMaskBit44 = 0x100000000000 + CBitFieldMaskBit45 = 0x200000000000 + CBitFieldMaskBit46 = 0x400000000000 + CBitFieldMaskBit47 = 0x800000000000 + CBitFieldMaskBit48 = 0x1000000000000 + CBitFieldMaskBit49 = 0x2000000000000 + CBitFieldMaskBit50 = 0x4000000000000 + CBitFieldMaskBit51 = 
0x8000000000000 + CBitFieldMaskBit52 = 0x10000000000000 + CBitFieldMaskBit53 = 0x20000000000000 + CBitFieldMaskBit54 = 0x40000000000000 + CBitFieldMaskBit55 = 0x80000000000000 + CBitFieldMaskBit56 = 0x100000000000000 + CBitFieldMaskBit57 = 0x200000000000000 + CBitFieldMaskBit58 = 0x400000000000000 + CBitFieldMaskBit59 = 0x800000000000000 + CBitFieldMaskBit60 = 0x1000000000000000 + CBitFieldMaskBit61 = 0x2000000000000000 + CBitFieldMaskBit62 = 0x4000000000000000 + CBitFieldMaskBit63 = 0x8000000000000000 +) + +type SockaddrStorage struct { + Family uint16 + Data [118]byte + _ uint64 +} + +type HDGeometry struct { + Heads uint8 + Sectors uint8 + Cylinders uint16 + Start uint64 +} + +type Statfs_t struct { + Type int64 + Bsize int64 + Blocks uint64 + Bfree uint64 + Bavail uint64 + Files uint64 + Ffree uint64 + Fsid Fsid + Namelen int64 + Frsize int64 + Flags int64 + Spare [4]int64 +} + +type TpacketHdr struct { + Status uint64 + Len uint32 + Snaplen uint32 + Mac uint16 + Net uint16 + Sec uint32 + Usec uint32 + _ [4]byte +} + +const ( + SizeofTpacketHdr = 0x20 +) + +type RTCPLLInfo struct { + Ctrl int32 + Value int32 + Max int32 + Min int32 + Posmult int32 + Negmult int32 + Clock int64 +} + +type BlkpgPartition struct { + Start int64 + Length int64 + Pno int32 + Devname [64]uint8 + Volname [64]uint8 + _ [4]byte +} + +const ( + BLKPG = 0x1269 +) + +type XDPUmemReg struct { + Addr uint64 + Len uint64 + Size uint32 + Headroom uint32 + Flags uint32 + _ [4]byte +} + +type CryptoUserAlg struct { + Name [64]uint8 + Driver_name [64]uint8 + Module_name [64]uint8 + Type uint32 + Mask uint32 + Refcnt uint32 + Flags uint32 +} + +type CryptoStatAEAD struct { + Type [64]uint8 + Encrypt_cnt uint64 + Encrypt_tlen uint64 + Decrypt_cnt uint64 + Decrypt_tlen uint64 + Err_cnt uint64 +} + +type CryptoStatAKCipher struct { + Type [64]uint8 + Encrypt_cnt uint64 + Encrypt_tlen uint64 + Decrypt_cnt uint64 + Decrypt_tlen uint64 + Verify_cnt uint64 + Sign_cnt uint64 + Err_cnt uint64 +} + +type CryptoStatCipher struct { + Type [64]uint8 + Encrypt_cnt uint64 + Encrypt_tlen uint64 + Decrypt_cnt uint64 + Decrypt_tlen uint64 + Err_cnt uint64 +} + +type CryptoStatCompress struct { + Type [64]uint8 + Compress_cnt uint64 + Compress_tlen uint64 + Decompress_cnt uint64 + Decompress_tlen uint64 + Err_cnt uint64 +} + +type CryptoStatHash struct { + Type [64]uint8 + Hash_cnt uint64 + Hash_tlen uint64 + Err_cnt uint64 +} + +type CryptoStatKPP struct { + Type [64]uint8 + Setsecret_cnt uint64 + Generate_public_key_cnt uint64 + Compute_shared_secret_cnt uint64 + Err_cnt uint64 +} + +type CryptoStatRNG struct { + Type [64]uint8 + Generate_cnt uint64 + Generate_tlen uint64 + Seed_cnt uint64 + Err_cnt uint64 +} + +type CryptoStatLarval struct { + Type [64]uint8 +} + +type CryptoReportLarval struct { + Type [64]uint8 +} + +type CryptoReportHash struct { + Type [64]uint8 + Blocksize uint32 + Digestsize uint32 +} + +type CryptoReportCipher struct { + Type [64]uint8 + Blocksize uint32 + Min_keysize uint32 + Max_keysize uint32 +} + +type CryptoReportBlkCipher struct { + Type [64]uint8 + Geniv [64]uint8 + Blocksize uint32 + Min_keysize uint32 + Max_keysize uint32 + Ivsize uint32 +} + +type CryptoReportAEAD struct { + Type [64]uint8 + Geniv [64]uint8 + Blocksize uint32 + Maxauthsize uint32 + Ivsize uint32 +} + +type CryptoReportComp struct { + Type [64]uint8 +} + +type CryptoReportRNG struct { + Type [64]uint8 + Seedsize uint32 +} + +type CryptoReportAKCipher struct { + Type [64]uint8 +} + +type CryptoReportKPP struct { + Type [64]uint8 +} + +type 
CryptoReportAcomp struct { + Type [64]uint8 +} + +type LoopInfo struct { + Number int32 + Device uint32 + Inode uint64 + Rdevice uint32 + Offset int32 + Encrypt_type int32 + Encrypt_key_size int32 + Flags int32 + Name [64]uint8 + Encrypt_key [32]uint8 + Init [2]uint64 + Reserved [4]uint8 + _ [4]byte +} + +type TIPCSubscr struct { + Seq TIPCServiceRange + Timeout uint32 + Filter uint32 + Handle [8]uint8 +} + +type TIPCSIOCLNReq struct { + Peer uint32 + Id uint32 + Linkname [68]uint8 +} + +type TIPCSIOCNodeIDReq struct { + Peer uint32 + Id [16]uint8 +} + +type PPSKInfo struct { + Assert_sequence uint32 + Clear_sequence uint32 + Assert_tu PPSKTime + Clear_tu PPSKTime + Current_mode int32 + _ [4]byte +} + +const ( + PPS_GETPARAMS = 0x800870a1 + PPS_SETPARAMS = 0x400870a2 + PPS_GETCAP = 0x800870a3 + PPS_FETCH = 0xc00870a4 +) + +const ( + PIDFD_NONBLOCK = 0x800 +) + +type SysvIpcPerm struct { + Key int32 + Uid uint32 + Gid uint32 + Cuid uint32 + Cgid uint32 + Mode uint32 + _ [0]uint8 + Seq uint16 + _ uint16 + _ uint64 + _ uint64 +} +type SysvShmDesc struct { + Perm SysvIpcPerm + Segsz uint64 + Atime int64 + Dtime int64 + Ctime int64 + Cpid int32 + Lpid int32 + Nattch uint64 + _ uint64 + _ uint64 +} + +type RISCVHWProbePairs struct { + Key int64 + Value uint64 +} + +const ( + RISCV_HWPROBE_KEY_MVENDORID = 0x0 + RISCV_HWPROBE_KEY_MARCHID = 0x1 + RISCV_HWPROBE_KEY_MIMPID = 0x2 + RISCV_HWPROBE_KEY_BASE_BEHAVIOR = 0x3 + RISCV_HWPROBE_BASE_BEHAVIOR_IMA = 0x1 + RISCV_HWPROBE_KEY_IMA_EXT_0 = 0x4 + RISCV_HWPROBE_IMA_FD = 0x1 + RISCV_HWPROBE_IMA_C = 0x2 + RISCV_HWPROBE_IMA_V = 0x4 + RISCV_HWPROBE_EXT_ZBA = 0x8 + RISCV_HWPROBE_EXT_ZBB = 0x10 + RISCV_HWPROBE_EXT_ZBS = 0x20 + RISCV_HWPROBE_KEY_CPUPERF_0 = 0x5 + RISCV_HWPROBE_MISALIGNED_UNKNOWN = 0x0 + RISCV_HWPROBE_MISALIGNED_EMULATED = 0x1 + RISCV_HWPROBE_MISALIGNED_SLOW = 0x2 + RISCV_HWPROBE_MISALIGNED_FAST = 0x3 + RISCV_HWPROBE_MISALIGNED_UNSUPPORTED = 0x4 + RISCV_HWPROBE_MISALIGNED_MASK = 0x7 +) diff --git a/vendor/golang.org/x/tools/cmd/stringer/stringer.go b/vendor/golang.org/x/tools/cmd/stringer/stringer.go new file mode 100644 index 0000000000..998d1a51bf --- /dev/null +++ b/vendor/golang.org/x/tools/cmd/stringer/stringer.go @@ -0,0 +1,657 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Stringer is a tool to automate the creation of methods that satisfy the fmt.Stringer +// interface. Given the name of a (signed or unsigned) integer type T that has constants +// defined, stringer will create a new self-contained Go source file implementing +// +// func (t T) String() string +// +// The file is created in the same package and directory as the package that defines T. +// It has helpful defaults designed for use with go generate. +// +// Stringer works best with constants that are consecutive values such as created using iota, +// but creates good code regardless. In the future it might also provide custom support for +// constant sets that are bit patterns. 
+//
+// For example, given this snippet,
+//
+//	package painkiller
+//
+//	type Pill int
+//
+//	const (
+//		Placebo Pill = iota
+//		Aspirin
+//		Ibuprofen
+//		Paracetamol
+//		Acetaminophen = Paracetamol
+//	)
+//
+// running this command
+//
+//	stringer -type=Pill
+//
+// in the same directory will create the file pill_string.go, in package painkiller,
+// containing a definition of
+//
+//	func (Pill) String() string
+//
+// That method will translate the value of a Pill constant to the string representation
+// of the respective constant name, so that the call fmt.Print(painkiller.Aspirin) will
+// print the string "Aspirin".
+//
+// Typically this process would be run using go generate, like this:
+//
+//	//go:generate stringer -type=Pill
+//
+// If multiple constants have the same value, the lexically first matching name will
+// be used (in the example, Acetaminophen will print as "Paracetamol").
+//
+// With no arguments, it processes the package in the current directory.
+// Otherwise, the arguments must name a single directory holding a Go package
+// or a set of Go source files that represent a single Go package.
+//
+// The -type flag accepts a comma-separated list of types so a single run can
+// generate methods for multiple types. The default output file is t_string.go,
+// where t is the lower-cased name of the first type listed. It can be overridden
+// with the -output flag.
+//
+// The -linecomment flag tells stringer to generate the text of any line comment, trimmed
+// of leading spaces, instead of the constant name. For instance, if the constants above had a
+// Pill prefix, one could write
+//
+//	PillAspirin // Aspirin
+//
+// to suppress it in the output.
+package main // import "golang.org/x/tools/cmd/stringer"
+
+import (
+	"bytes"
+	"flag"
+	"fmt"
+	"go/ast"
+	"go/constant"
+	"go/format"
+	"go/token"
+	"go/types"
+	"log"
+	"os"
+	"path/filepath"
+	"sort"
+	"strings"
+
+	"golang.org/x/tools/go/packages"
+)
+
+var (
+	typeNames   = flag.String("type", "", "comma-separated list of type names; must be set")
+	output      = flag.String("output", "", "output file name; default srcdir/<type>_string.go")
+	trimprefix  = flag.String("trimprefix", "", "trim the `prefix` from the generated constant names")
+	linecomment = flag.Bool("linecomment", false, "use line comment text as printed text when present")
+	buildTags   = flag.String("tags", "", "comma-separated list of build tags to apply")
+)
+
+// Usage is a replacement usage function for the flags package.
+func Usage() {
+	fmt.Fprintf(os.Stderr, "Usage of stringer:\n")
+	fmt.Fprintf(os.Stderr, "\tstringer [flags] -type T [directory]\n")
+	fmt.Fprintf(os.Stderr, "\tstringer [flags] -type T files... # Must be a single package\n")
+	fmt.Fprintf(os.Stderr, "For more information, see:\n")
+	fmt.Fprintf(os.Stderr, "\thttps://pkg.go.dev/golang.org/x/tools/cmd/stringer\n")
+	fmt.Fprintf(os.Stderr, "Flags:\n")
+	flag.PrintDefaults()
+}
+
+func main() {
+	log.SetFlags(0)
+	log.SetPrefix("stringer: ")
+	flag.Usage = Usage
+	flag.Parse()
+	if len(*typeNames) == 0 {
+		flag.Usage()
+		os.Exit(2)
+	}
+	types := strings.Split(*typeNames, ",")
+	var tags []string
+	if len(*buildTags) > 0 {
+		tags = strings.Split(*buildTags, ",")
+	}
+
+	// We accept either one directory or a list of files. Which do we have?
+	args := flag.Args()
+	if len(args) == 0 {
+		// Default: process whole package in current directory.
+		args = []string{"."}
+	}
+
+	// Parse the package once.
+ var dir string + g := Generator{ + trimPrefix: *trimprefix, + lineComment: *linecomment, + } + // TODO(suzmue): accept other patterns for packages (directories, list of files, import paths, etc). + if len(args) == 1 && isDirectory(args[0]) { + dir = args[0] + } else { + if len(tags) != 0 { + log.Fatal("-tags option applies only to directories, not when files are specified") + } + dir = filepath.Dir(args[0]) + } + + g.parsePackage(args, tags) + + // Print the header and package clause. + g.Printf("// Code generated by \"stringer %s\"; DO NOT EDIT.\n", strings.Join(os.Args[1:], " ")) + g.Printf("\n") + g.Printf("package %s", g.pkg.name) + g.Printf("\n") + g.Printf("import \"strconv\"\n") // Used by all methods. + + // Run generate for each type. + for _, typeName := range types { + g.generate(typeName) + } + + // Format the output. + src := g.format() + + // Write to file. + outputName := *output + if outputName == "" { + baseName := fmt.Sprintf("%s_string.go", types[0]) + outputName = filepath.Join(dir, strings.ToLower(baseName)) + } + err := os.WriteFile(outputName, src, 0644) + if err != nil { + log.Fatalf("writing output: %s", err) + } +} + +// isDirectory reports whether the named file is a directory. +func isDirectory(name string) bool { + info, err := os.Stat(name) + if err != nil { + log.Fatal(err) + } + return info.IsDir() +} + +// Generator holds the state of the analysis. Primarily used to buffer +// the output for format.Source. +type Generator struct { + buf bytes.Buffer // Accumulated output. + pkg *Package // Package we are scanning. + + trimPrefix string + lineComment bool +} + +func (g *Generator) Printf(format string, args ...interface{}) { + fmt.Fprintf(&g.buf, format, args...) +} + +// File holds a single parsed file and associated data. +type File struct { + pkg *Package // Package to which this file belongs. + file *ast.File // Parsed AST. + // These fields are reset for each type being generated. + typeName string // Name of the constant type. + values []Value // Accumulator for constant values of that type. + + trimPrefix string + lineComment bool +} + +type Package struct { + name string + defs map[*ast.Ident]types.Object + files []*File +} + +// parsePackage analyzes the single package constructed from the patterns and tags. +// parsePackage exits if there is an error. +func (g *Generator) parsePackage(patterns []string, tags []string) { + cfg := &packages.Config{ + Mode: packages.NeedName | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax, + // TODO: Need to think about constants in test files. Maybe write type_string_test.go + // in a separate pass? For later. + Tests: false, + BuildFlags: []string{fmt.Sprintf("-tags=%s", strings.Join(tags, " "))}, + } + pkgs, err := packages.Load(cfg, patterns...) + if err != nil { + log.Fatal(err) + } + if len(pkgs) != 1 { + log.Fatalf("error: %d packages found", len(pkgs)) + } + g.addPackage(pkgs[0]) +} + +// addPackage adds a type checked Package and its syntax files to the generator. +func (g *Generator) addPackage(pkg *packages.Package) { + g.pkg = &Package{ + name: pkg.Name, + defs: pkg.TypesInfo.Defs, + files: make([]*File, len(pkg.Syntax)), + } + + for i, file := range pkg.Syntax { + g.pkg.files[i] = &File{ + file: file, + pkg: g.pkg, + trimPrefix: g.trimPrefix, + lineComment: g.lineComment, + } + } +} + +// generate produces the String method for the named type. 
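+//
+// Depending on how the constant values cluster into runs of consecutive
+// values, generate emits one of three shapes: a single index/name pair, a
+// switch over several runs, or a map for sparse value sets (see the
+// run-count switch in the function body).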
+func (g *Generator) generate(typeName string) { + values := make([]Value, 0, 100) + for _, file := range g.pkg.files { + // Set the state for this run of the walker. + file.typeName = typeName + file.values = nil + if file.file != nil { + ast.Inspect(file.file, file.genDecl) + values = append(values, file.values...) + } + } + + if len(values) == 0 { + log.Fatalf("no values defined for type %s", typeName) + } + // Generate code that will fail if the constants change value. + g.Printf("func _() {\n") + g.Printf("\t// An \"invalid array index\" compiler error signifies that the constant values have changed.\n") + g.Printf("\t// Re-run the stringer command to generate them again.\n") + g.Printf("\tvar x [1]struct{}\n") + for _, v := range values { + g.Printf("\t_ = x[%s - %s]\n", v.originalName, v.str) + } + g.Printf("}\n") + runs := splitIntoRuns(values) + // The decision of which pattern to use depends on the number of + // runs in the numbers. If there's only one, it's easy. For more than + // one, there's a tradeoff between complexity and size of the data + // and code vs. the simplicity of a map. A map takes more space, + // but so does the code. The decision here (crossover at 10) is + // arbitrary, but considers that for large numbers of runs the cost + // of the linear scan in the switch might become important, and + // rather than use yet another algorithm such as binary search, + // we punt and use a map. In any case, the likelihood of a map + // being necessary for any realistic example other than bitmasks + // is very low. And bitmasks probably deserve their own analysis, + // to be done some other day. + switch { + case len(runs) == 1: + g.buildOneRun(runs, typeName) + case len(runs) <= 10: + g.buildMultipleRuns(runs, typeName) + default: + g.buildMap(runs, typeName) + } +} + +// splitIntoRuns breaks the values into runs of contiguous sequences. +// For example, given 1,2,3,5,6,7 it returns {1,2,3},{5,6,7}. +// The input slice is known to be non-empty. +func splitIntoRuns(values []Value) [][]Value { + // We use stable sort so the lexically first name is chosen for equal elements. + sort.Stable(byValue(values)) + // Remove duplicates. Stable sort has put the one we want to print first, + // so use that one. The String method won't care about which named constant + // was the argument, so the first name for the given value is the only one to keep. + // We need to do this because identical values would cause the switch or map + // to fail to compile. + j := 1 + for i := 1; i < len(values); i++ { + if values[i].value != values[i-1].value { + values[j] = values[i] + j++ + } + } + values = values[:j] + runs := make([][]Value, 0, 10) + for len(values) > 0 { + // One contiguous sequence per outer loop. + i := 1 + for i < len(values) && values[i].value == values[i-1].value+1 { + i++ + } + runs = append(runs, values[:i]) + values = values[i:] + } + return runs +} + +// format returns the gofmt-ed contents of the Generator's buffer. +func (g *Generator) format() []byte { + src, err := format.Source(g.buf.Bytes()) + if err != nil { + // Should never happen, but can arise when developing this code. + // The user can compile the output to see the error. + log.Printf("warning: internal error: invalid Go generated: %s", err) + log.Printf("warning: compile the package to analyze the error") + return g.buf.Bytes() + } + return src +} + +// Value represents a declared constant. +type Value struct { + originalName string // The name of the constant. + name string // The name with trimmed prefix. 
+ // The value is stored as a bit pattern alone. The boolean tells us + // whether to interpret it as an int64 or a uint64; the only place + // this matters is when sorting. + // Much of the time the str field is all we need; it is printed + // by Value.String. + value uint64 // Will be converted to int64 when needed. + signed bool // Whether the constant is a signed type. + str string // The string representation given by the "go/constant" package. +} + +func (v *Value) String() string { + return v.str +} + +// byValue lets us sort the constants into increasing order. +// We take care in the Less method to sort in signed or unsigned order, +// as appropriate. +type byValue []Value + +func (b byValue) Len() int { return len(b) } +func (b byValue) Swap(i, j int) { b[i], b[j] = b[j], b[i] } +func (b byValue) Less(i, j int) bool { + if b[i].signed { + return int64(b[i].value) < int64(b[j].value) + } + return b[i].value < b[j].value +} + +// genDecl processes one declaration clause. +func (f *File) genDecl(node ast.Node) bool { + decl, ok := node.(*ast.GenDecl) + if !ok || decl.Tok != token.CONST { + // We only care about const declarations. + return true + } + // The name of the type of the constants we are declaring. + // Can change if this is a multi-element declaration. + typ := "" + // Loop over the elements of the declaration. Each element is a ValueSpec: + // a list of names possibly followed by a type, possibly followed by values. + // If the type and value are both missing, we carry down the type (and value, + // but the "go/types" package takes care of that). + for _, spec := range decl.Specs { + vspec := spec.(*ast.ValueSpec) // Guaranteed to succeed as this is CONST. + if vspec.Type == nil && len(vspec.Values) > 0 { + // "X = 1". With no type but a value. If the constant is untyped, + // skip this vspec and reset the remembered type. + typ = "" + + // If this is a simple type conversion, remember the type. + // We don't mind if this is actually a call; a qualified call won't + // be matched (that will be SelectorExpr, not Ident), and only unusual + // situations will result in a function call that appears to be + // a type conversion. + ce, ok := vspec.Values[0].(*ast.CallExpr) + if !ok { + continue + } + id, ok := ce.Fun.(*ast.Ident) + if !ok { + continue + } + typ = id.Name + } + if vspec.Type != nil { + // "X T". We have a type. Remember it. + ident, ok := vspec.Type.(*ast.Ident) + if !ok { + continue + } + typ = ident.Name + } + if typ != f.typeName { + // This is not the type we're looking for. + continue + } + // We now have a list of names (from one line of source code) all being + // declared with the desired type. + // Grab their names and actual values and store them in f.values. + for _, name := range vspec.Names { + if name.Name == "_" { + continue + } + // This dance lets the type checker find the values for us. It's a + // bit tricky: look up the object declared by the name, find its + // types.Const, and extract its value. + obj, ok := f.pkg.defs[name] + if !ok { + log.Fatalf("no value for constant %s", name) + } + info := obj.Type().Underlying().(*types.Basic).Info() + if info&types.IsInteger == 0 { + log.Fatalf("can't handle non-integer constant type %s", typ) + } + value := obj.(*types.Const).Val() // Guaranteed to succeed as this is CONST. 
+ if value.Kind() != constant.Int { + log.Fatalf("can't happen: constant is not an integer %s", name) + } + i64, isInt := constant.Int64Val(value) + u64, isUint := constant.Uint64Val(value) + if !isInt && !isUint { + log.Fatalf("internal error: value of %s is not an integer: %s", name, value.String()) + } + if !isInt { + u64 = uint64(i64) + } + v := Value{ + originalName: name.Name, + value: u64, + signed: info&types.IsUnsigned == 0, + str: value.String(), + } + if c := vspec.Comment; f.lineComment && c != nil && len(c.List) == 1 { + v.name = strings.TrimSpace(c.Text()) + } else { + v.name = strings.TrimPrefix(v.originalName, f.trimPrefix) + } + f.values = append(f.values, v) + } + } + return false +} + +// Helpers + +// usize returns the number of bits of the smallest unsigned integer +// type that will hold n. Used to create the smallest possible slice of +// integers to use as indexes into the concatenated strings. +func usize(n int) int { + switch { + case n < 1<<8: + return 8 + case n < 1<<16: + return 16 + default: + // 2^32 is enough constants for anyone. + return 32 + } +} + +// declareIndexAndNameVars declares the index slices and concatenated names +// strings representing the runs of values. +func (g *Generator) declareIndexAndNameVars(runs [][]Value, typeName string) { + var indexes, names []string + for i, run := range runs { + index, name := g.createIndexAndNameDecl(run, typeName, fmt.Sprintf("_%d", i)) + if len(run) != 1 { + indexes = append(indexes, index) + } + names = append(names, name) + } + g.Printf("const (\n") + for _, name := range names { + g.Printf("\t%s\n", name) + } + g.Printf(")\n\n") + + if len(indexes) > 0 { + g.Printf("var (") + for _, index := range indexes { + g.Printf("\t%s\n", index) + } + g.Printf(")\n\n") + } +} + +// declareIndexAndNameVar is the single-run version of declareIndexAndNameVars +func (g *Generator) declareIndexAndNameVar(run []Value, typeName string) { + index, name := g.createIndexAndNameDecl(run, typeName, "") + g.Printf("const %s\n", name) + g.Printf("var %s\n", index) +} + +// createIndexAndNameDecl returns the pair of declarations for the run. The caller will add "const" and "var". +func (g *Generator) createIndexAndNameDecl(run []Value, typeName string, suffix string) (string, string) { + b := new(bytes.Buffer) + indexes := make([]int, len(run)) + for i := range run { + b.WriteString(run[i].name) + indexes[i] = b.Len() + } + nameConst := fmt.Sprintf("_%s_name%s = %q", typeName, suffix, b.String()) + nameLen := b.Len() + b.Reset() + fmt.Fprintf(b, "_%s_index%s = [...]uint%d{0, ", typeName, suffix, usize(nameLen)) + for i, v := range indexes { + if i > 0 { + fmt.Fprintf(b, ", ") + } + fmt.Fprintf(b, "%d", v) + } + fmt.Fprintf(b, "}") + return b.String(), nameConst +} + +// declareNameVars declares the concatenated names string representing all the values in the runs. +func (g *Generator) declareNameVars(runs [][]Value, typeName string, suffix string) { + g.Printf("const _%s_name%s = \"", typeName, suffix) + for _, run := range runs { + for i := range run { + g.Printf("%s", run[i].name) + } + } + g.Printf("\"\n") +} + +// buildOneRun generates the variables and String method for a single run of contiguous values. +func (g *Generator) buildOneRun(runs [][]Value, typeName string) { + values := runs[0] + g.Printf("\n") + g.declareIndexAndNameVar(values, typeName) + // The generated code is simple enough to write as a Printf format. 
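+	// Illustrative sketch: for the Pill example in the package comment
+	// (four consecutive signed constants starting at zero), the emitted
+	// declarations and method look roughly like:
+	//
+	//	const _Pill_name = "PlaceboAspirinIbuprofenParacetamol"
+	//	var _Pill_index = [...]uint8{0, 7, 14, 23, 34}
+	//
+	//	func (i Pill) String() string {
+	//		if i < 0 || i >= Pill(len(_Pill_index)-1) {
+	//			return "Pill(" + strconv.FormatInt(int64(i), 10) + ")"
+	//		}
+	//		return _Pill_name[_Pill_index[i]:_Pill_index[i+1]]
+	//	}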
+ lessThanZero := "" + if values[0].signed { + lessThanZero = "i < 0 || " + } + if values[0].value == 0 { // Signed or unsigned, 0 is still 0. + g.Printf(stringOneRun, typeName, usize(len(values)), lessThanZero) + } else { + g.Printf(stringOneRunWithOffset, typeName, values[0].String(), usize(len(values)), lessThanZero) + } +} + +// Arguments to format are: +// +// [1]: type name +// [2]: size of index element (8 for uint8 etc.) +// [3]: less than zero check (for signed types) +const stringOneRun = `func (i %[1]s) String() string { + if %[3]si >= %[1]s(len(_%[1]s_index)-1) { + return "%[1]s(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _%[1]s_name[_%[1]s_index[i]:_%[1]s_index[i+1]] +} +` + +// Arguments to format are: +// [1]: type name +// [2]: lowest defined value for type, as a string +// [3]: size of index element (8 for uint8 etc.) +// [4]: less than zero check (for signed types) +/* + */ +const stringOneRunWithOffset = `func (i %[1]s) String() string { + i -= %[2]s + if %[4]si >= %[1]s(len(_%[1]s_index)-1) { + return "%[1]s(" + strconv.FormatInt(int64(i + %[2]s), 10) + ")" + } + return _%[1]s_name[_%[1]s_index[i] : _%[1]s_index[i+1]] +} +` + +// buildMultipleRuns generates the variables and String method for multiple runs of contiguous values. +// For this pattern, a single Printf format won't do. +func (g *Generator) buildMultipleRuns(runs [][]Value, typeName string) { + g.Printf("\n") + g.declareIndexAndNameVars(runs, typeName) + g.Printf("func (i %s) String() string {\n", typeName) + g.Printf("\tswitch {\n") + for i, values := range runs { + if len(values) == 1 { + g.Printf("\tcase i == %s:\n", &values[0]) + g.Printf("\t\treturn _%s_name_%d\n", typeName, i) + continue + } + if values[0].value == 0 && !values[0].signed { + // For an unsigned lower bound of 0, "0 <= i" would be redundant. + g.Printf("\tcase i <= %s:\n", &values[len(values)-1]) + } else { + g.Printf("\tcase %s <= i && i <= %s:\n", &values[0], &values[len(values)-1]) + } + if values[0].value != 0 { + g.Printf("\t\ti -= %s\n", &values[0]) + } + g.Printf("\t\treturn _%s_name_%d[_%s_index_%d[i]:_%s_index_%d[i+1]]\n", + typeName, i, typeName, i, typeName, i) + } + g.Printf("\tdefault:\n") + g.Printf("\t\treturn \"%s(\" + strconv.FormatInt(int64(i), 10) + \")\"\n", typeName) + g.Printf("\t}\n") + g.Printf("}\n") +} + +// buildMap handles the case where the space is so sparse a map is a reasonable fallback. +// It's a rare situation but has simple code. +func (g *Generator) buildMap(runs [][]Value, typeName string) { + g.Printf("\n") + g.declareNameVars(runs, typeName, "") + g.Printf("\nvar _%s_map = map[%s]string{\n", typeName, typeName) + n := 0 + for _, values := range runs { + for _, value := range values { + g.Printf("\t%s: _%s_name[%d:%d],\n", &value, typeName, n, n+len(value.name)) + n += len(value.name) + } + } + g.Printf("}\n\n") + g.Printf(stringMap, typeName) +} + +// Argument to format is the type name. +const stringMap = `func (i %[1]s) String() string { + if str, ok := _%[1]s_map[i]; ok { + return str + } + return "%[1]s(" + strconv.FormatInt(int64(i), 10) + ")" +} +` diff --git a/vendor/golang.org/x/tools/go/packages/doc.go b/vendor/golang.org/x/tools/go/packages/doc.go new file mode 100644 index 0000000000..da4ab89fe6 --- /dev/null +++ b/vendor/golang.org/x/tools/go/packages/doc.go @@ -0,0 +1,220 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+
+/*
+Package packages loads Go packages for inspection and analysis.
+
+The Load function takes as input a list of patterns and returns a list of Package
+structs describing individual packages matched by those patterns.
+The LoadMode controls the amount of detail in the loaded packages.
+
+Load passes most patterns directly to the underlying build tool,
+but all patterns with the prefix "query=", where query is a
+non-empty string of letters from [a-z], are reserved and may be
+interpreted as query operators.
+
+Two query operators are currently supported: "file" and "pattern".
+
+The query "file=path/to/file.go" matches the package or packages enclosing
+the Go source file path/to/file.go. For example "file=~/go/src/fmt/print.go"
+might return the packages "fmt" and "fmt [fmt.test]".
+
+The query "pattern=string" causes "string" to be passed directly to
+the underlying build tool. In most cases this is unnecessary,
+but an application can use Load("pattern=" + x) as an escaping mechanism
+to ensure that x is not interpreted as a query operator if it contains '='.
+
+All other query operators are reserved for future use and currently
+cause Load to report an error.
+
+The Package struct provides basic information about the package, including
+
+  - ID, a unique identifier for the package in the returned set;
+  - GoFiles, the names of the package's Go source files;
+  - Imports, a map from source import strings to the Packages they name;
+  - Types, the type information for the package's exported symbols;
+  - Syntax, the parsed syntax trees for the package's source code; and
+  - TypeInfo, the result of a complete type-check of the package syntax trees.
+
+(See the documentation for type Package for the complete list of fields
+and more detailed descriptions.)
+
+For example,
+
+	Load(nil, "bytes", "unicode...")
+
+returns four Package structs describing the standard library packages
+bytes, unicode, unicode/utf16, and unicode/utf8. Note that one pattern
+can match multiple packages and that a package might be matched by
+multiple patterns: in general it is not possible to determine which
+packages correspond to which patterns.
+
+Note that the list returned by Load contains only the packages matched
+by the patterns. Their dependencies can be found by walking the import
+graph using the Imports fields.
+
+The Load function can be configured by passing a pointer to a Config as
+the first argument. A nil Config is equivalent to the zero Config, which
+causes Load to run in LoadFiles mode, collecting minimal information.
+See the documentation for type Config for details.
+
+As noted earlier, the Config.Mode controls the amount of detail
+reported about the loaded packages. See the documentation for type LoadMode
+for details.
+
+Most tools should pass their command-line arguments (after any flags)
+uninterpreted to the loader, so that the loader can interpret them
+according to the conventions of the underlying build system.
+See the Example function for typical usage.
+*/
+package packages // import "golang.org/x/tools/go/packages"
+
+/*
+
+Motivation and design considerations
+
+The new package's design solves problems addressed by two existing
+packages: go/build, which locates and describes packages, and
+golang.org/x/tools/go/loader, which loads, parses and type-checks them.
+The go/build.Package structure encodes too much of the 'go build' way
+of organizing projects, leaving us in need of a data type that describes a
+package of Go source code independent of the underlying build system.
+We wanted something that works equally well with go build and vgo, and +also other build systems such as Bazel and Blaze, making it possible to +construct analysis tools that work in all these environments. +Tools such as errcheck and staticcheck were essentially unavailable to +the Go community at Google, and some of Google's internal tools for Go +are unavailable externally. +This new package provides a uniform way to obtain package metadata by +querying each of these build systems, optionally supporting their +preferred command-line notations for packages, so that tools integrate +neatly with users' build environments. The Metadata query function +executes an external query tool appropriate to the current workspace. + +Loading packages always returns the complete import graph "all the way down", +even if all you want is information about a single package, because the query +mechanisms of all the build systems we currently support ({go,vgo} list, and +blaze/bazel aspect-based query) cannot provide detailed information +about one package without visiting all its dependencies too, so there is +no additional asymptotic cost to providing transitive information. +(This property might not be true of a hypothetical 5th build system.) + +In calls to TypeCheck, all initial packages, and any package that +transitively depends on one of them, must be loaded from source. +Consider A->B->C->D->E: if A,C are initial, A,B,C must be loaded from +source; D may be loaded from export data, and E may not be loaded at all +(though it's possible that D's export data mentions it, so a +types.Package may be created for it and exposed.) + +The old loader had a feature to suppress type-checking of function +bodies on a per-package basis, primarily intended to reduce the work of +obtaining type information for imported packages. Now that imports are +satisfied by export data, the optimization no longer seems necessary. + +Despite some early attempts, the old loader did not exploit export data, +instead always using the equivalent of WholeProgram mode. This was due +to the complexity of mixing source and export data packages (now +resolved by the upward traversal mentioned above), and because export data +files were nearly always missing or stale. Now that 'go build' supports +caching, all the underlying build systems can guarantee to produce +export data in a reasonable (amortized) time. + +Test "main" packages synthesized by the build system are now reported as +first-class packages, avoiding the need for clients (such as go/ssa) to +reinvent this generation logic. + +One way in which go/packages is simpler than the old loader is in its +treatment of in-package tests. In-package tests are packages that +consist of all the files of the library under test, plus the test files. +The old loader constructed in-package tests by a two-phase process of +mutation called "augmentation": first it would construct and type check +all the ordinary library packages and type-check the packages that +depend on them; then it would add more (test) files to the package and +type-check again. This two-phase approach had four major problems: +1) in processing the tests, the loader modified the library package, + leaving no way for a client application to see both the test + package and the library package; one would mutate into the other. 
+2) because test files can declare additional methods on types defined in
+   the library portion of the package, the dispatch of method calls in
+   the library portion was affected by the presence of the test files.
+   This should have been a clue that the packages were logically
+   different.
+3) this model of "augmentation" assumed at most one in-package test
+   per library package, which is true of projects using 'go build',
+   but not other build systems.
+4) because of the two-phase nature of test processing, all packages that
+   import the library package had to be processed before augmentation,
+   forcing a "one-shot" API and preventing the client from calling Load
+   several times in sequence as is now possible in WholeProgram mode.
+   (TypeCheck mode has a similar one-shot restriction for a different reason.)
+
+Early drafts of this package supported "multi-shot" operation.
+Although it allowed clients to make a sequence of calls (or concurrent
+calls) to Load, building up the graph of Packages incrementally,
+it was of marginal value: it complicated the API
+(since it allowed some options to vary across calls but not others),
+it complicated the implementation,
+it cannot be made to work in Types mode, as explained above,
+and it was less efficient than making one combined call (when this is possible).
+Among the clients we have inspected, none made multiple calls to Load
+but could not be easily and satisfactorily modified to make only a single call.
+However, application changes may be required.
+For example, the ssadump command loads the user-specified packages
+and in addition the runtime package. It is tempting to simply append
+"runtime" to the user-provided list, but that does not work if the user
+specified an ad-hoc package such as [a.go b.go].
+Instead, ssadump no longer requests the runtime package,
+but seeks it among the dependencies of the user-specified packages,
+and emits an error if it is not found.
+
+Overlays: The Overlay field in the Config allows providing alternate contents
+for Go source files, by providing a mapping from file path to contents.
+go/packages will pull in new imports added in overlay files when go/packages
+is run in LoadImports mode or greater.
+Overlay support for the go list driver isn't complete yet: if the file doesn't
+exist on disk, it will only be recognized in an overlay if it is a non-test file
+and the package would be reported even without the overlay.
+
+Questions & Tasks
+
+- Add GOARCH/GOOS?
+  They are not portable concepts, but could be made portable.
+  Our goal has been to allow users to express themselves using the conventions
+  of the underlying build system: if the build system honors GOARCH
+  during a build and during a metadata query, then so should
+  applications built atop that query mechanism.
+  Conversely, if the target architecture of the build is determined by
+  command-line flags, the application can pass the relevant
+  flags through to the build system using a command such as:
+    myapp -query_flag="--cpu=amd64" -query_flag="--os=darwin"
+  However, this approach is low-level, unwieldy, and non-portable.
+  GOOS and GOARCH seem important enough to warrant a dedicated option.
+
+- How should we handle partial failures such as a mixture of good and
+  malformed patterns, existing and non-existent packages, successful and
+  failed builds, import failures, import cycles, and so on, in a call to
+  Load?
+
+- Support bazel, blaze, and go1.10 list, not just go1.11 list.
+
+- Handle (and test) various partial success cases, e.g.
+ a mixture of good packages and: + invalid patterns + nonexistent packages + empty packages + packages with malformed package or import declarations + unreadable files + import cycles + other parse errors + type errors + Make sure we record errors at the correct place in the graph. + +- Missing packages among initial arguments are not reported. + Return bogus packages for them, like golist does. + +- "undeclared name" errors (for example) are reported out of source file + order. I suspect this is due to the breadth-first resolution now used + by go/types. Is that a bug? Discuss with gri. + +*/ diff --git a/vendor/golang.org/x/tools/go/packages/golist.go b/vendor/golang.org/x/tools/go/packages/golist.go new file mode 100644 index 0000000000..b5de9cf9f2 --- /dev/null +++ b/vendor/golang.org/x/tools/go/packages/golist.go @@ -0,0 +1,1182 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package packages + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io/ioutil" + "log" + "os" + "path" + "path/filepath" + "reflect" + "sort" + "strconv" + "strings" + "sync" + "unicode" + + exec "golang.org/x/sys/execabs" + "golang.org/x/tools/go/internal/packagesdriver" + "golang.org/x/tools/internal/gocommand" + "golang.org/x/tools/internal/packagesinternal" +) + +// debug controls verbose logging. +var debug, _ = strconv.ParseBool(os.Getenv("GOPACKAGESDEBUG")) + +// A goTooOldError reports that the go command +// found by exec.LookPath is too old to use the new go list behavior. +type goTooOldError struct { + error +} + +// responseDeduper wraps a driverResponse, deduplicating its contents. +type responseDeduper struct { + seenRoots map[string]bool + seenPackages map[string]*Package + dr *driverResponse +} + +func newDeduper() *responseDeduper { + return &responseDeduper{ + dr: &driverResponse{}, + seenRoots: map[string]bool{}, + seenPackages: map[string]*Package{}, + } +} + +// addAll fills in r with a driverResponse. +func (r *responseDeduper) addAll(dr *driverResponse) { + for _, pkg := range dr.Packages { + r.addPackage(pkg) + } + for _, root := range dr.Roots { + r.addRoot(root) + } + r.dr.GoVersion = dr.GoVersion +} + +func (r *responseDeduper) addPackage(p *Package) { + if r.seenPackages[p.ID] != nil { + return + } + r.seenPackages[p.ID] = p + r.dr.Packages = append(r.dr.Packages, p) +} + +func (r *responseDeduper) addRoot(id string) { + if r.seenRoots[id] { + return + } + r.seenRoots[id] = true + r.dr.Roots = append(r.dr.Roots, id) +} + +type golistState struct { + cfg *Config + ctx context.Context + + envOnce sync.Once + goEnvError error + goEnv map[string]string + + rootsOnce sync.Once + rootDirsError error + rootDirs map[string]string + + goVersionOnce sync.Once + goVersionError error + goVersion int // The X in Go 1.X. + + // vendorDirs caches the (non)existence of vendor directories. + vendorDirs map[string]bool +} + +// getEnv returns Go environment variables. Only specific variables are +// populated -- computing all of them is slow. 
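+// Illustrative sketch: the underlying `go env -json GOMOD GOPATH` call
+// prints JSON along the lines of (paths hypothetical):
+//
+//	{
+//		"GOMOD": "/home/user/proj/go.mod",
+//		"GOPATH": "/home/user/go"
+//	}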
+func (state *golistState) getEnv() (map[string]string, error) {
+	state.envOnce.Do(func() {
+		var b *bytes.Buffer
+		b, state.goEnvError = state.invokeGo("env", "-json", "GOMOD", "GOPATH")
+		if state.goEnvError != nil {
+			return
+		}
+
+		state.goEnv = make(map[string]string)
+		decoder := json.NewDecoder(b)
+		if state.goEnvError = decoder.Decode(&state.goEnv); state.goEnvError != nil {
+			return
+		}
+	})
+	return state.goEnv, state.goEnvError
+}
+
+// mustGetEnv is a convenience function that can be used if getEnv has already succeeded.
+func (state *golistState) mustGetEnv() map[string]string {
+	env, err := state.getEnv()
+	if err != nil {
+		panic(fmt.Sprintf("mustGetEnv: %v", err))
+	}
+	return env
+}
+
+// goListDriver uses the go list command to interpret the patterns and produce
+// the build system package structure.
+// See driver for more details.
+func goListDriver(cfg *Config, patterns ...string) (*driverResponse, error) {
+	// Make sure that any asynchronous go commands are killed when we return.
+	parentCtx := cfg.Context
+	if parentCtx == nil {
+		parentCtx = context.Background()
+	}
+	ctx, cancel := context.WithCancel(parentCtx)
+	defer cancel()
+
+	response := newDeduper()
+
+	state := &golistState{
+		cfg:        cfg,
+		ctx:        ctx,
+		vendorDirs: map[string]bool{},
+	}
+
+	// Fill in response.dr.Compiler and response.dr.Arch asynchronously if necessary.
+	var sizeserr error
+	var sizeswg sync.WaitGroup
+	if cfg.Mode&NeedTypesSizes != 0 || cfg.Mode&NeedTypes != 0 {
+		sizeswg.Add(1)
+		go func() {
+			compiler, arch, err := packagesdriver.GetSizesForArgsGolist(ctx, state.cfgInvocation(), cfg.gocmdRunner)
+			sizeserr = err
+			response.dr.Compiler = compiler
+			response.dr.Arch = arch
+			sizeswg.Done()
+		}()
+	}
+
+	// Determine files requested in contains patterns
+	var containFiles []string
+	restPatterns := make([]string, 0, len(patterns))
+	// Extract file= and other [querytype]= patterns. Report an error if querytype
+	// doesn't exist.
+extractQueries:
+	for _, pattern := range patterns {
+		eqidx := strings.Index(pattern, "=")
+		if eqidx < 0 {
+			restPatterns = append(restPatterns, pattern)
+		} else {
+			query, value := pattern[:eqidx], pattern[eqidx+len("="):]
+			switch query {
+			case "file":
+				containFiles = append(containFiles, value)
+			case "pattern":
+				restPatterns = append(restPatterns, value)
+			case "": // not a reserved query
+				restPatterns = append(restPatterns, pattern)
+			default:
+				for _, rune := range query {
+					if rune < 'a' || rune > 'z' { // not a reserved query
+						restPatterns = append(restPatterns, pattern)
+						continue extractQueries
+					}
+				}
+				// Reject all other patterns containing "="
+				return nil, fmt.Errorf("invalid query type %q in query pattern %q", query, pattern)
+			}
+		}
+	}
+
+	// See if we have any patterns to pass through to go list. Zero initial
+	// patterns also require a go list call, since it's the equivalent of
+	// ".".
+	if len(restPatterns) > 0 || len(patterns) == 0 {
+		dr, err := state.createDriverResponse(restPatterns...)
+		if err != nil {
+			return nil, err
+		}
+		response.addAll(dr)
+	}
+
+	if len(containFiles) != 0 {
+		if err := state.runContainsQueries(response, containFiles); err != nil {
+			return nil, err
+		}
+	}
+
+	// Only use go/packages' overlay processing if we're using a Go version
+	// below 1.16. Otherwise, go list handles it.
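+	// Illustrative sketch of what such an overlay looks like from the
+	// client's side (path and contents hypothetical):
+	//
+	//	cfg := &packages.Config{
+	//		Overlay: map[string][]byte{
+	//			"/home/user/proj/extra.go": []byte("package proj\n"),
+	//		},
+	//	}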
+ if goVersion, err := state.getGoVersion(); err == nil && goVersion < 16 { + modifiedPkgs, needPkgs, err := state.processGolistOverlay(response) + if err != nil { + return nil, err + } + + var containsCandidates []string + if len(containFiles) > 0 { + containsCandidates = append(containsCandidates, modifiedPkgs...) + containsCandidates = append(containsCandidates, needPkgs...) + } + if err := state.addNeededOverlayPackages(response, needPkgs); err != nil { + return nil, err + } + // Check candidate packages for containFiles. + if len(containFiles) > 0 { + for _, id := range containsCandidates { + pkg, ok := response.seenPackages[id] + if !ok { + response.addPackage(&Package{ + ID: id, + Errors: []Error{{ + Kind: ListError, + Msg: fmt.Sprintf("package %s expected but not seen", id), + }}, + }) + continue + } + for _, f := range containFiles { + for _, g := range pkg.GoFiles { + if sameFile(f, g) { + response.addRoot(id) + } + } + } + } + } + // Add root for any package that matches a pattern. This applies only to + // packages that are modified by overlays, since they are not added as + // roots automatically. + for _, pattern := range restPatterns { + match := matchPattern(pattern) + for _, pkgID := range modifiedPkgs { + pkg, ok := response.seenPackages[pkgID] + if !ok { + continue + } + if match(pkg.PkgPath) { + response.addRoot(pkg.ID) + } + } + } + } + + sizeswg.Wait() + if sizeserr != nil { + return nil, sizeserr + } + return response.dr, nil +} + +func (state *golistState) addNeededOverlayPackages(response *responseDeduper, pkgs []string) error { + if len(pkgs) == 0 { + return nil + } + dr, err := state.createDriverResponse(pkgs...) + if err != nil { + return err + } + for _, pkg := range dr.Packages { + response.addPackage(pkg) + } + _, needPkgs, err := state.processGolistOverlay(response) + if err != nil { + return err + } + return state.addNeededOverlayPackages(response, needPkgs) +} + +func (state *golistState) runContainsQueries(response *responseDeduper, queries []string) error { + for _, query := range queries { + // TODO(matloob): Do only one query per directory. + fdir := filepath.Dir(query) + // Pass absolute path of directory to go list so that it knows to treat it as a directory, + // not a package path. + pattern, err := filepath.Abs(fdir) + if err != nil { + return fmt.Errorf("could not determine absolute path of file= query path %q: %v", query, err) + } + dirResponse, err := state.createDriverResponse(pattern) + + // If there was an error loading the package, or no packages are returned, + // or the package is returned with errors, try to load the file as an + // ad-hoc package. + // Usually the error will appear in a returned package, but may not if we're + // in module mode and the ad-hoc is located outside a module. + if err != nil || len(dirResponse.Packages) == 0 || len(dirResponse.Packages) == 1 && len(dirResponse.Packages[0].GoFiles) == 0 && + len(dirResponse.Packages[0].Errors) == 1 { + var queryErr error + if dirResponse, queryErr = state.adhocPackage(pattern, query); queryErr != nil { + return err // return the original error + } + } + isRoot := make(map[string]bool, len(dirResponse.Roots)) + for _, root := range dirResponse.Roots { + isRoot[root] = true + } + for _, pkg := range dirResponse.Packages { + // Add any new packages to the main set + // We don't bother to filter packages that will be dropped by the changes of roots, + // that will happen anyway during graph construction outside this function. + // Over-reporting packages is not a problem. 
+ response.addPackage(pkg) + // if the package was not a root one, it cannot have the file + if !isRoot[pkg.ID] { + continue + } + for _, pkgFile := range pkg.GoFiles { + if filepath.Base(query) == filepath.Base(pkgFile) { + response.addRoot(pkg.ID) + break + } + } + } + } + return nil +} + +// adhocPackage attempts to load or construct an ad-hoc package for a given +// query, if the original call to the driver produced inadequate results. +func (state *golistState) adhocPackage(pattern, query string) (*driverResponse, error) { + response, err := state.createDriverResponse(query) + if err != nil { + return nil, err + } + // If we get nothing back from `go list`, + // try to make this file into its own ad-hoc package. + // TODO(rstambler): Should this check against the original response? + if len(response.Packages) == 0 { + response.Packages = append(response.Packages, &Package{ + ID: "command-line-arguments", + PkgPath: query, + GoFiles: []string{query}, + CompiledGoFiles: []string{query}, + Imports: make(map[string]*Package), + }) + response.Roots = append(response.Roots, "command-line-arguments") + } + // Handle special cases. + if len(response.Packages) == 1 { + // golang/go#33482: If this is a file= query for ad-hoc packages where + // the file only exists on an overlay, and exists outside of a module, + // add the file to the package and remove the errors. + if response.Packages[0].ID == "command-line-arguments" || + filepath.ToSlash(response.Packages[0].PkgPath) == filepath.ToSlash(query) { + if len(response.Packages[0].GoFiles) == 0 { + filename := filepath.Join(pattern, filepath.Base(query)) // avoid recomputing abspath + // TODO(matloob): check if the file is outside of a root dir? + for path := range state.cfg.Overlay { + if path == filename { + response.Packages[0].Errors = nil + response.Packages[0].GoFiles = []string{path} + response.Packages[0].CompiledGoFiles = []string{path} + } + } + } + } + } + return response, nil +} + +// Fields must match go list; +// see $GOROOT/src/cmd/go/internal/load/pkg.go. +type jsonPackage struct { + ImportPath string + Dir string + Name string + Export string + GoFiles []string + CompiledGoFiles []string + IgnoredGoFiles []string + IgnoredOtherFiles []string + EmbedPatterns []string + EmbedFiles []string + CFiles []string + CgoFiles []string + CXXFiles []string + MFiles []string + HFiles []string + FFiles []string + SFiles []string + SwigFiles []string + SwigCXXFiles []string + SysoFiles []string + Imports []string + ImportMap map[string]string + Deps []string + Module *Module + TestGoFiles []string + TestImports []string + XTestGoFiles []string + XTestImports []string + ForTest string // q in a "p [q.test]" package, else "" + DepOnly bool + + Error *packagesinternal.PackageError + DepsErrors []*packagesinternal.PackageError +} + +type jsonPackageError struct { + ImportStack []string + Pos string + Err string +} + +func otherFiles(p *jsonPackage) [][]string { + return [][]string{p.CFiles, p.CXXFiles, p.MFiles, p.HFiles, p.FFiles, p.SFiles, p.SwigFiles, p.SwigCXXFiles, p.SysoFiles} +} + +// createDriverResponse uses the "go list" command to expand the pattern +// words and return a response for the specified packages. 
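+// Illustrative sketch of the underlying invocation; the exact flag set is
+// assembled by golistargs and varies with the Go version and the requested
+// LoadMode:
+//
+//	go list -e -json -compiled=true -test=true -deps=true -- <patterns>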
+func (state *golistState) createDriverResponse(words ...string) (*driverResponse, error) { + // go list uses the following identifiers in ImportPath and Imports: + // + // "p" -- importable package or main (command) + // "q.test" -- q's test executable + // "p [q.test]" -- variant of p as built for q's test executable + // "q_test [q.test]" -- q's external test package + // + // The packages p that are built differently for a test q.test + // are q itself, plus any helpers used by the external test q_test, + // typically including "testing" and all its dependencies. + + // Run "go list" for complete + // information on the specified packages. + goVersion, err := state.getGoVersion() + if err != nil { + return nil, err + } + buf, err := state.invokeGo("list", golistargs(state.cfg, words, goVersion)...) + if err != nil { + return nil, err + } + + seen := make(map[string]*jsonPackage) + pkgs := make(map[string]*Package) + additionalErrors := make(map[string][]Error) + // Decode the JSON and convert it to Package form. + response := &driverResponse{ + GoVersion: goVersion, + } + for dec := json.NewDecoder(buf); dec.More(); { + p := new(jsonPackage) + if err := dec.Decode(p); err != nil { + return nil, fmt.Errorf("JSON decoding failed: %v", err) + } + + if p.ImportPath == "" { + // The documentation for go list says that “[e]rroneous packages will have + // a non-empty ImportPath”. If for some reason it comes back empty, we + // prefer to error out rather than silently discarding data or handing + // back a package without any way to refer to it. + if p.Error != nil { + return nil, Error{ + Pos: p.Error.Pos, + Msg: p.Error.Err, + } + } + return nil, fmt.Errorf("package missing import path: %+v", p) + } + + // Work around https://golang.org/issue/33157: + // go list -e, when given an absolute path, will find the package contained at + // that directory. But when no package exists there, it will return a fake package + // with an error and the ImportPath set to the absolute path provided to go list. + // Try to convert that absolute path to what its package path would be if it's + // contained in a known module or GOPATH entry. This will allow the package to be + // properly "reclaimed" when overlays are processed. + if filepath.IsAbs(p.ImportPath) && p.Error != nil { + pkgPath, ok, err := state.getPkgPath(p.ImportPath) + if err != nil { + return nil, err + } + if ok { + p.ImportPath = pkgPath + } + } + + if old, found := seen[p.ImportPath]; found { + // If one version of the package has an error, and the other doesn't, assume + // that this is a case where go list is reporting a fake dependency variant + // of the imported package: When a package tries to invalidly import another + // package, go list emits a variant of the imported package (with the same + // import path, but with an error on it, and the package will have a + // DepError set on it). An example of when this can happen is for imports of + // main packages: main packages can not be imported, but they may be + // separately matched and listed by another pattern. + // See golang.org/issue/36188 for more details. + + // The plan is that eventually, hopefully in Go 1.15, the error will be + // reported on the importing package rather than the duplicate "fake" + // version of the imported package. Once all supported versions of Go + // have the new behavior this logic can be deleted. + // TODO(matloob): delete the workaround logic once all supported versions of + // Go return the errors on the proper package. 
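+			//
+			// Concretely (hypothetical module path): a library that imports the
+			// main package "example.com/cmd" can cause go list to emit both the
+			// real "example.com/cmd" and an errored duplicate under the same
+			// ImportPath; the logic below keeps the error-free copy and attaches
+			// the error to the importing package instead.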
+ + // There should be exactly one version of a package that doesn't have an + // error. + if old.Error == nil && p.Error == nil { + if !reflect.DeepEqual(p, old) { + return nil, fmt.Errorf("internal error: go list gives conflicting information for package %v", p.ImportPath) + } + continue + } + + // Determine if this package's error needs to be bubbled up. + // This is a hack, and we expect for go list to eventually set the error + // on the package. + if old.Error != nil { + var errkind string + if strings.Contains(old.Error.Err, "not an importable package") { + errkind = "not an importable package" + } else if strings.Contains(old.Error.Err, "use of internal package") && strings.Contains(old.Error.Err, "not allowed") { + errkind = "use of internal package not allowed" + } + if errkind != "" { + if len(old.Error.ImportStack) < 1 { + return nil, fmt.Errorf(`internal error: go list gave a %q error with empty import stack`, errkind) + } + importingPkg := old.Error.ImportStack[len(old.Error.ImportStack)-1] + if importingPkg == old.ImportPath { + // Using an older version of Go which put this package itself on top of import + // stack, instead of the importer. Look for importer in second from top + // position. + if len(old.Error.ImportStack) < 2 { + return nil, fmt.Errorf(`internal error: go list gave a %q error with an import stack without importing package`, errkind) + } + importingPkg = old.Error.ImportStack[len(old.Error.ImportStack)-2] + } + additionalErrors[importingPkg] = append(additionalErrors[importingPkg], Error{ + Pos: old.Error.Pos, + Msg: old.Error.Err, + Kind: ListError, + }) + } + } + + // Make sure that if there's a version of the package without an error, + // that's the one reported to the user. + if old.Error == nil { + continue + } + + // This package will replace the old one at the end of the loop. + } + seen[p.ImportPath] = p + + pkg := &Package{ + Name: p.Name, + ID: p.ImportPath, + GoFiles: absJoin(p.Dir, p.GoFiles, p.CgoFiles), + CompiledGoFiles: absJoin(p.Dir, p.CompiledGoFiles), + OtherFiles: absJoin(p.Dir, otherFiles(p)...), + EmbedFiles: absJoin(p.Dir, p.EmbedFiles), + EmbedPatterns: absJoin(p.Dir, p.EmbedPatterns), + IgnoredFiles: absJoin(p.Dir, p.IgnoredGoFiles, p.IgnoredOtherFiles), + forTest: p.ForTest, + depsErrors: p.DepsErrors, + Module: p.Module, + } + + if (state.cfg.Mode&typecheckCgo) != 0 && len(p.CgoFiles) != 0 { + if len(p.CompiledGoFiles) > len(p.GoFiles) { + // We need the cgo definitions, which are in the first + // CompiledGoFile after the non-cgo ones. This is a hack but there + // isn't currently a better way to find it. We also need the pure + // Go files and unprocessed cgo files, all of which are already + // in pkg.GoFiles. + cgoTypes := p.CompiledGoFiles[len(p.GoFiles)] + pkg.CompiledGoFiles = append([]string{cgoTypes}, pkg.GoFiles...) + } else { + // golang/go#38990: go list silently fails to do cgo processing + pkg.CompiledGoFiles = nil + pkg.Errors = append(pkg.Errors, Error{ + Msg: "go list failed to return CompiledGoFiles. This may indicate failure to perform cgo processing; try building at the command line. See https://golang.org/issue/38990.", + Kind: ListError, + }) + } + } + + // Work around https://golang.org/issue/28749: + // cmd/go puts assembly, C, and C++ files in CompiledGoFiles. + // Remove files from CompiledGoFiles that are non-go files + // (or are not files that look like they are from the cache). 
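+		// (Illustrative: a cgo-processed file coming from the build cache has
+		// no extension, e.g. "$GOCACHE/3f/3fe1...-d", and is kept, while
+		// "asm.s" or "impl.c" are dropped.)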
+		if len(pkg.CompiledGoFiles) > 0 {
+			out := pkg.CompiledGoFiles[:0]
+			for _, f := range pkg.CompiledGoFiles {
+				if ext := filepath.Ext(f); ext != ".go" && ext != "" { // ext == "" means the file is from the cache, so probably cgo-processed file
+					continue
+				}
+				out = append(out, f)
+			}
+			pkg.CompiledGoFiles = out
+		}
+
+		// Extract the PkgPath from the package's ID.
+		if i := strings.IndexByte(pkg.ID, ' '); i >= 0 {
+			pkg.PkgPath = pkg.ID[:i]
+		} else {
+			pkg.PkgPath = pkg.ID
+		}
+
+		if pkg.PkgPath == "unsafe" {
+			pkg.CompiledGoFiles = nil // ignore fake unsafe.go file (#59929)
+		} else if len(pkg.CompiledGoFiles) == 0 {
+			// Workaround for pre-go1.11 versions of go list.
+			// TODO(matloob): they should be handled by the fallback.
+			// Can we delete this?
+			pkg.CompiledGoFiles = pkg.GoFiles
+		}
+
+		// Assume go list emits only absolute paths for Dir.
+		if p.Dir != "" && !filepath.IsAbs(p.Dir) {
+			log.Fatalf("internal error: go list returned non-absolute Package.Dir: %s", p.Dir)
+		}
+
+		if p.Export != "" && !filepath.IsAbs(p.Export) {
+			pkg.ExportFile = filepath.Join(p.Dir, p.Export)
+		} else {
+			pkg.ExportFile = p.Export
+		}
+
+		// imports
+		//
+		// Imports contains the IDs of all imported packages.
+		// ImportMap records (path, ID) only where they differ.
+		ids := make(map[string]bool)
+		for _, id := range p.Imports {
+			ids[id] = true
+		}
+		pkg.Imports = make(map[string]*Package)
+		for path, id := range p.ImportMap {
+			pkg.Imports[path] = &Package{ID: id} // non-identity import
+			delete(ids, id)
+		}
+		for id := range ids {
+			if id == "C" {
+				continue
+			}
+
+			pkg.Imports[id] = &Package{ID: id} // identity import
+		}
+		if !p.DepOnly {
+			response.Roots = append(response.Roots, pkg.ID)
+		}
+
+		// Temporary work-around for golang/go#39986. Parse filenames out of
+		// error messages. This happens if there are unrecoverable syntax
+		// errors in the source, so we can't match on a specific error message.
+		//
+		// TODO(rfindley): remove this heuristic, in favor of considering
+		// InvalidGoFiles from the list driver.
+		if err := p.Error; err != nil && state.shouldAddFilenameFromError(p) {
+			addFilenameFromPos := func(pos string) bool {
+				split := strings.Split(pos, ":")
+				if len(split) < 1 {
+					return false
+				}
+				filename := strings.TrimSpace(split[0])
+				if filename == "" {
+					return false
+				}
+				if !filepath.IsAbs(filename) {
+					filename = filepath.Join(state.cfg.Dir, filename)
+				}
+				info, _ := os.Stat(filename)
+				if info == nil {
+					return false
+				}
+				pkg.CompiledGoFiles = append(pkg.CompiledGoFiles, filename)
+				pkg.GoFiles = append(pkg.GoFiles, filename)
+				return true
+			}
+			found := addFilenameFromPos(err.Pos)
+			// In some cases, go list only reports the error position in the
+			// error text, not in the error's Pos field. One such case is when
+			// the file's package name is a keyword (see golang.org/issue/39763).
+			if !found {
+				addFilenameFromPos(err.Err)
+			}
+		}
+
+		if p.Error != nil {
+			msg := strings.TrimSpace(p.Error.Err) // Trim to work around golang.org/issue/32363.
+			// Address golang.org/issue/35964 by appending import stack to error message.
+			if msg == "import cycle not allowed" && len(p.Error.ImportStack) != 0 {
+				msg += fmt.Sprintf(": import stack: %v", p.Error.ImportStack)
+			}
+			pkg.Errors = append(pkg.Errors, Error{
+				Pos:  p.Error.Pos,
+				Msg:  msg,
+				Kind: ListError,
+			})
+		}
+
+		pkgs[pkg.ID] = pkg
+	}
+
+	for id, errs := range additionalErrors {
+		if p, ok := pkgs[id]; ok {
+			p.Errors = append(p.Errors, errs...)
+		}
+	}
+	for _, pkg := range pkgs {
+		response.Packages = append(response.Packages, pkg)
+	}
+	sort.Slice(response.Packages, func(i, j int) bool { return response.Packages[i].ID < response.Packages[j].ID })
+
+	return response, nil
+}
+
+func (state *golistState) shouldAddFilenameFromError(p *jsonPackage) bool {
+	if len(p.GoFiles) > 0 || len(p.CompiledGoFiles) > 0 {
+		return false
+	}
+
+	goV, err := state.getGoVersion()
+	if err != nil {
+		return false
+	}
+
+	// On Go 1.14 and earlier, only add filenames from errors if the import stack is empty.
+	// The import stack behaves differently for these versions than newer Go versions.
+	if goV < 15 {
+		return len(p.Error.ImportStack) == 0
+	}
+
+	// On Go 1.15 and later, only parse filenames out of error if there's no import stack,
+	// or the current package is at the top of the import stack. This is not guaranteed
+	// to work perfectly, but should avoid some cases where files in errors don't belong to this
+	// package.
+	return len(p.Error.ImportStack) == 0 || p.Error.ImportStack[len(p.Error.ImportStack)-1] == p.ImportPath
+}
+
+// getGoVersion returns the effective minor version of the go command.
+func (state *golistState) getGoVersion() (int, error) {
+	state.goVersionOnce.Do(func() {
+		state.goVersion, state.goVersionError = gocommand.GoVersion(state.ctx, state.cfgInvocation(), state.cfg.gocmdRunner)
+	})
+	return state.goVersion, state.goVersionError
+}
+
+// getPkgPath finds the package path of a directory if it's relative to a root
+// directory.
+func (state *golistState) getPkgPath(dir string) (string, bool, error) {
+	absDir, err := filepath.Abs(dir)
+	if err != nil {
+		return "", false, err
+	}
+	roots, err := state.determineRootDirs()
+	if err != nil {
+		return "", false, err
+	}
+
+	for rdir, rpath := range roots {
+		// Make sure that the directory is in the module,
+		// to avoid creating a path relative to another module.
+		if !strings.HasPrefix(absDir, rdir) {
+			continue
+		}
+		// TODO(matloob): This doesn't properly handle symlinks.
+		r, err := filepath.Rel(rdir, dir)
+		if err != nil {
+			continue
+		}
+		if rpath != "" {
+			// We choose only one root even though the directory can belong to multiple modules
+			// or GOPATH entries. This is okay because we only need to work with absolute dirs when a
+			// file is missing from disk, for instance when gopls calls go/packages in an overlay.
+			// Once the file is saved, gopls, or the next invocation of the tool, will get the correct
+			// result straight from go list.
+			// TODO(matloob): Implement module tiebreaking?
+			return path.Join(rpath, filepath.ToSlash(r)), true, nil
+		}
+		return filepath.ToSlash(r), true, nil
+	}
+	return "", false, nil
+}
+
+// absJoin absolutizes and flattens the lists of files.
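+//
+// For example (illustrative arguments):
+//
+//	absJoin("/work/m", []string{"a.go"}, []string{"/tmp/b.go"})
+//	// returns []string{"/work/m/a.go", "/tmp/b.go"}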
+func absJoin(dir string, fileses ...[]string) (res []string) { + for _, files := range fileses { + for _, file := range files { + if !filepath.IsAbs(file) { + file = filepath.Join(dir, file) + } + res = append(res, file) + } + } + return res +} + +func jsonFlag(cfg *Config, goVersion int) string { + if goVersion < 19 { + return "-json" + } + var fields []string + added := make(map[string]bool) + addFields := func(fs ...string) { + for _, f := range fs { + if !added[f] { + added[f] = true + fields = append(fields, f) + } + } + } + addFields("Name", "ImportPath", "Error") // These fields are always needed + if cfg.Mode&NeedFiles != 0 || cfg.Mode&NeedTypes != 0 { + addFields("Dir", "GoFiles", "IgnoredGoFiles", "IgnoredOtherFiles", "CFiles", + "CgoFiles", "CXXFiles", "MFiles", "HFiles", "FFiles", "SFiles", + "SwigFiles", "SwigCXXFiles", "SysoFiles") + if cfg.Tests { + addFields("TestGoFiles", "XTestGoFiles") + } + } + if cfg.Mode&NeedTypes != 0 { + // CompiledGoFiles seems to be required for the test case TestCgoNoSyntax, + // even when -compiled isn't passed in. + // TODO(#52435): Should we make the test ask for -compiled, or automatically + // request CompiledGoFiles in certain circumstances? + addFields("Dir", "CompiledGoFiles") + } + if cfg.Mode&NeedCompiledGoFiles != 0 { + addFields("Dir", "CompiledGoFiles", "Export") + } + if cfg.Mode&NeedImports != 0 { + // When imports are requested, DepOnly is used to distinguish between packages + // explicitly requested and transitive imports of those packages. + addFields("DepOnly", "Imports", "ImportMap") + if cfg.Tests { + addFields("TestImports", "XTestImports") + } + } + if cfg.Mode&NeedDeps != 0 { + addFields("DepOnly") + } + if usesExportData(cfg) { + // Request Dir in the unlikely case Export is not absolute. + addFields("Dir", "Export") + } + if cfg.Mode&needInternalForTest != 0 { + addFields("ForTest") + } + if cfg.Mode&needInternalDepsErrors != 0 { + addFields("DepsErrors") + } + if cfg.Mode&NeedModule != 0 { + addFields("Module") + } + if cfg.Mode&NeedEmbedFiles != 0 { + addFields("EmbedFiles") + } + if cfg.Mode&NeedEmbedPatterns != 0 { + addFields("EmbedPatterns") + } + return "-json=" + strings.Join(fields, ",") +} + +func golistargs(cfg *Config, words []string, goVersion int) []string { + const findFlags = NeedImports | NeedTypes | NeedSyntax | NeedTypesInfo + fullargs := []string{ + "-e", jsonFlag(cfg, goVersion), + fmt.Sprintf("-compiled=%t", cfg.Mode&(NeedCompiledGoFiles|NeedSyntax|NeedTypes|NeedTypesInfo|NeedTypesSizes) != 0), + fmt.Sprintf("-test=%t", cfg.Tests), + fmt.Sprintf("-export=%t", usesExportData(cfg)), + fmt.Sprintf("-deps=%t", cfg.Mode&NeedImports != 0), + // go list doesn't let you pass -test and -find together, + // probably because you'd just get the TestMain. + fmt.Sprintf("-find=%t", !cfg.Tests && cfg.Mode&findFlags == 0 && !usesExportData(cfg)), + } + + // golang/go#60456: with go1.21 and later, go list serves pgo variants, which + // can be costly to compute and may result in redundant processing for the + // caller. Disable these variants. If someone wants to add e.g. a NeedPGO + // mode flag, that should be a separate proposal. + if goVersion >= 21 { + fullargs = append(fullargs, "-pgo=off") + } + + fullargs = append(fullargs, cfg.BuildFlags...) + fullargs = append(fullargs, "--") + fullargs = append(fullargs, words...) + return fullargs +} + +// cfgInvocation returns an Invocation that reflects cfg's settings. 
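+//
+// A minimal sketch of the flow (illustrative values): a caller-level
+//
+//	cfg := &Config{Dir: "/src/app", BuildFlags: []string{"-tags=integration"}}
+//
+// yields an Invocation that runs the go command in /src/app, with
+// -tags=integration also forwarded on the go list command line by
+// golistargs above.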
+func (state *golistState) cfgInvocation() gocommand.Invocation {
+	cfg := state.cfg
+	return gocommand.Invocation{
+		BuildFlags: cfg.BuildFlags,
+		ModFile:    cfg.modFile,
+		ModFlag:    cfg.modFlag,
+		CleanEnv:   cfg.Env != nil,
+		Env:        cfg.Env,
+		Logf:       cfg.Logf,
+		WorkingDir: cfg.Dir,
+	}
+}
+
+// invokeGo returns the stdout of a go command invocation.
+func (state *golistState) invokeGo(verb string, args ...string) (*bytes.Buffer, error) {
+	cfg := state.cfg
+
+	inv := state.cfgInvocation()
+
+	// For Go versions 1.16 and above, `go list` accepts overlays directly via
+	// the -overlay flag. Set it, if it's available.
+	//
+	// The check for "list" is not necessarily required, but we should avoid
+	// getting the go version if possible.
+	if verb == "list" {
+		goVersion, err := state.getGoVersion()
+		if err != nil {
+			return nil, err
+		}
+		if goVersion >= 16 {
+			filename, cleanup, err := state.writeOverlays()
+			if err != nil {
+				return nil, err
+			}
+			defer cleanup()
+			inv.Overlay = filename
+		}
+	}
+	inv.Verb = verb
+	inv.Args = args
+	gocmdRunner := cfg.gocmdRunner
+	if gocmdRunner == nil {
+		gocmdRunner = &gocommand.Runner{}
+	}
+	stdout, stderr, friendlyErr, err := gocmdRunner.RunRaw(cfg.Context, inv)
+	if err != nil {
+		// Check for 'go' executable not being found.
+		if ee, ok := err.(*exec.Error); ok && ee.Err == exec.ErrNotFound {
+			return nil, fmt.Errorf("'go list' driver requires 'go', but %s", exec.ErrNotFound)
+		}
+
+		exitErr, ok := err.(*exec.ExitError)
+		if !ok {
+			// Catastrophic error:
+			// - context cancellation
+			return nil, fmt.Errorf("couldn't run 'go': %w", err)
+		}
+
+		// Old go version?
+		if strings.Contains(stderr.String(), "flag provided but not defined") {
+			return nil, goTooOldError{fmt.Errorf("unsupported version of go: %s: %s", exitErr, stderr)}
+		}
+
+		// Related to #24854
+		if len(stderr.String()) > 0 && strings.Contains(stderr.String(), "unexpected directory layout") {
+			return nil, friendlyErr
+		}
+
+		// Is there an error running the C compiler in cgo? This will be reported in the "Error" field
+		// and should be suppressed by go list -e.
+		//
+		// This condition is not perfect yet because the error message can include error messages other
+		// than those from runtime/cgo.
+		isPkgPathRune := func(r rune) bool {
+			// From https://golang.org/ref/spec#Import_declarations:
+			// Implementation restriction: A compiler may restrict ImportPaths to non-empty strings
+			// using only characters belonging to Unicode's L, M, N, P, and S general categories
+			// (the Graphic characters without spaces) and may also exclude the
+			// characters !"#$%&'()*,:;<=>?[\]^`{|} and the Unicode replacement character U+FFFD.
+			return unicode.IsOneOf([]*unicode.RangeTable{unicode.L, unicode.M, unicode.N, unicode.P, unicode.S}, r) &&
+				!strings.ContainsRune("!\"#$%&'()*,:;<=>?[\\]^`{|}\uFFFD", r)
+		}
+		// golang/go#36770: Handle case where cmd/go prints module download messages before the error.
+		msg := stderr.String()
+		for strings.HasPrefix(msg, "go: downloading") {
+			msg = msg[strings.IndexRune(msg, '\n')+1:]
+		}
+		if len(stderr.String()) > 0 && strings.HasPrefix(stderr.String(), "# ") {
+			msg := msg[len("# "):]
+			if strings.HasPrefix(strings.TrimLeftFunc(msg, isPkgPathRune), "\n") {
+				return stdout, nil
+			}
+			// Treat pkg-config errors as a special case (golang.org/issue/36770).
+			if strings.HasPrefix(msg, "pkg-config") {
+				return stdout, nil
+			}
+		}
+
+		// This error only appears in stderr. See golang.org/cl/166398 for a fix in go list to show
+		// the error in the Err section of stdout in case the -e option is provided.
+		// This fix is provided for backwards compatibility.
+		if len(stderr.String()) > 0 && strings.Contains(stderr.String(), "named files must be .go files") {
+			output := fmt.Sprintf(`{"ImportPath": "command-line-arguments","Incomplete": true,"Error": {"Pos": "","Err": %q}}`,
+				strings.Trim(stderr.String(), "\n"))
+			return bytes.NewBufferString(output), nil
+		}
+
+		// Similar to the previous error, but currently lacks a fix in Go.
+		if len(stderr.String()) > 0 && strings.Contains(stderr.String(), "named files must all be in one directory") {
+			output := fmt.Sprintf(`{"ImportPath": "command-line-arguments","Incomplete": true,"Error": {"Pos": "","Err": %q}}`,
+				strings.Trim(stderr.String(), "\n"))
+			return bytes.NewBufferString(output), nil
+		}
+
+		// Backwards compatibility for Go 1.11 because 1.12 and 1.13 put the directory in the ImportPath.
+		// If the package doesn't exist, put the absolute path of the directory into the error message,
+		// as Go 1.13 list does.
+		const noSuchDirectory = "no such directory"
+		if len(stderr.String()) > 0 && strings.Contains(stderr.String(), noSuchDirectory) {
+			errstr := stderr.String()
+			abspath := strings.TrimSpace(errstr[strings.Index(errstr, noSuchDirectory)+len(noSuchDirectory):])
+			output := fmt.Sprintf(`{"ImportPath": %q,"Incomplete": true,"Error": {"Pos": "","Err": %q}}`,
+				abspath, strings.Trim(stderr.String(), "\n"))
+			return bytes.NewBufferString(output), nil
+		}
+
+		// Workaround for #29280: go list -e has incorrect behavior when an ad-hoc package doesn't exist.
+		// Note that the error message we look for in this case is different from the one looked for above.
+		if len(stderr.String()) > 0 && strings.Contains(stderr.String(), "no such file or directory") {
+			output := fmt.Sprintf(`{"ImportPath": "command-line-arguments","Incomplete": true,"Error": {"Pos": "","Err": %q}}`,
+				strings.Trim(stderr.String(), "\n"))
+			return bytes.NewBufferString(output), nil
+		}
+
+		// Workaround for #34273. go list -e with GO111MODULE=on has incorrect behavior when listing a
+		// directory outside any module.
+		if len(stderr.String()) > 0 && strings.Contains(stderr.String(), "outside available modules") {
+			output := fmt.Sprintf(`{"ImportPath": %q,"Incomplete": true,"Error": {"Pos": "","Err": %q}}`,
+				// TODO(matloob): command-line-arguments isn't correct here.
+				"command-line-arguments", strings.Trim(stderr.String(), "\n"))
+			return bytes.NewBufferString(output), nil
+		}
+
+		// Another variation of the previous error.
+		if len(stderr.String()) > 0 && strings.Contains(stderr.String(), "outside module root") {
+			output := fmt.Sprintf(`{"ImportPath": %q,"Incomplete": true,"Error": {"Pos": "","Err": %q}}`,
+				// TODO(matloob): command-line-arguments isn't correct here.
+				"command-line-arguments", strings.Trim(stderr.String(), "\n"))
+			return bytes.NewBufferString(output), nil
+		}
+
+		// Workaround for an instance of golang.org/issue/26755: go list -e will return a non-zero exit
+		// status if there's a dependency on a package that doesn't exist. But it should return
+		// a zero exit status and set an error on that package.
+		if len(stderr.String()) > 0 && strings.Contains(stderr.String(), "no Go files in") {
+			// Don't clobber stdout if `go list` actually returned something.
+ if len(stdout.String()) > 0 { + return stdout, nil + } + // try to extract package name from string + stderrStr := stderr.String() + var importPath string + colon := strings.Index(stderrStr, ":") + if colon > 0 && strings.HasPrefix(stderrStr, "go build ") { + importPath = stderrStr[len("go build "):colon] + } + output := fmt.Sprintf(`{"ImportPath": %q,"Incomplete": true,"Error": {"Pos": "","Err": %q}}`, + importPath, strings.Trim(stderrStr, "\n")) + return bytes.NewBufferString(output), nil + } + + // Export mode entails a build. + // If that build fails, errors appear on stderr + // (despite the -e flag) and the Export field is blank. + // Do not fail in that case. + // The same is true if an ad-hoc package given to go list doesn't exist. + // TODO(matloob): Remove these once we can depend on go list to exit with a zero status with -e even when + // packages don't exist or a build fails. + if !usesExportData(cfg) && !containsGoFile(args) { + return nil, friendlyErr + } + } + return stdout, nil +} + +// OverlayJSON is the format overlay files are expected to be in. +// The Replace map maps from overlaid paths to replacement paths: +// the Go command will forward all reads trying to open +// each overlaid path to its replacement path, or consider the overlaid +// path not to exist if the replacement path is empty. +// +// From golang/go#39958. +type OverlayJSON struct { + Replace map[string]string `json:"replace,omitempty"` +} + +// writeOverlays writes out files for go list's -overlay flag, as described +// above. +func (state *golistState) writeOverlays() (filename string, cleanup func(), err error) { + // Do nothing if there are no overlays in the config. + if len(state.cfg.Overlay) == 0 { + return "", func() {}, nil + } + dir, err := ioutil.TempDir("", "gopackages-*") + if err != nil { + return "", nil, err + } + // The caller must clean up this directory, unless this function returns an + // error. + cleanup = func() { + os.RemoveAll(dir) + } + defer func() { + if err != nil { + cleanup() + } + }() + overlays := map[string]string{} + for k, v := range state.cfg.Overlay { + // Create a unique filename for the overlaid files, to avoid + // creating nested directories. + noSeparator := strings.Join(strings.Split(filepath.ToSlash(k), "/"), "") + f, err := ioutil.TempFile(dir, fmt.Sprintf("*-%s", noSeparator)) + if err != nil { + return "", func() {}, err + } + if _, err := f.Write(v); err != nil { + return "", func() {}, err + } + if err := f.Close(); err != nil { + return "", func() {}, err + } + overlays[k] = f.Name() + } + b, err := json.Marshal(OverlayJSON{Replace: overlays}) + if err != nil { + return "", func() {}, err + } + // Write out the overlay file that contains the filepath mappings. 
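+	// The marshalled bytes follow the OverlayJSON shape above, e.g.
+	// (illustrative): {"replace":{"/src/app/a.go":"/tmp/gopackages-x/1-a.go"}}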
+ filename = filepath.Join(dir, "overlay.json") + if err := ioutil.WriteFile(filename, b, 0665); err != nil { + return "", func() {}, err + } + return filename, cleanup, nil +} + +func containsGoFile(s []string) bool { + for _, f := range s { + if strings.HasSuffix(f, ".go") { + return true + } + } + return false +} + +func cmdDebugStr(cmd *exec.Cmd) string { + env := make(map[string]string) + for _, kv := range cmd.Env { + split := strings.SplitN(kv, "=", 2) + k, v := split[0], split[1] + env[k] = v + } + + var args []string + for _, arg := range cmd.Args { + quoted := strconv.Quote(arg) + if quoted[1:len(quoted)-1] != arg || strings.Contains(arg, " ") { + args = append(args, quoted) + } else { + args = append(args, arg) + } + } + return fmt.Sprintf("GOROOT=%v GOPATH=%v GO111MODULE=%v GOPROXY=%v PWD=%v %v", env["GOROOT"], env["GOPATH"], env["GO111MODULE"], env["GOPROXY"], env["PWD"], strings.Join(args, " ")) +} diff --git a/vendor/golang.org/x/tools/go/packages/packages.go b/vendor/golang.org/x/tools/go/packages/packages.go new file mode 100644 index 0000000000..124a6fe143 --- /dev/null +++ b/vendor/golang.org/x/tools/go/packages/packages.go @@ -0,0 +1,1334 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package packages + +// See doc.go for package documentation and implementation notes. + +import ( + "context" + "encoding/json" + "fmt" + "go/ast" + "go/parser" + "go/scanner" + "go/token" + "go/types" + "io" + "io/ioutil" + "log" + "os" + "path/filepath" + "runtime" + "strings" + "sync" + "time" + + "golang.org/x/tools/go/gcexportdata" + "golang.org/x/tools/internal/gocommand" + "golang.org/x/tools/internal/packagesinternal" + "golang.org/x/tools/internal/typeparams" + "golang.org/x/tools/internal/typesinternal" +) + +// A LoadMode controls the amount of detail to return when loading. +// The bits below can be combined to specify which fields should be +// filled in the result packages. +// The zero value is a special case, equivalent to combining +// the NeedName, NeedFiles, and NeedCompiledGoFiles bits. +// ID and Errors (if present) will always be filled. +// Load may return more information than requested. +type LoadMode int + +const ( + // NeedName adds Name and PkgPath. + NeedName LoadMode = 1 << iota + + // NeedFiles adds GoFiles and OtherFiles. + NeedFiles + + // NeedCompiledGoFiles adds CompiledGoFiles. + NeedCompiledGoFiles + + // NeedImports adds Imports. If NeedDeps is not set, the Imports field will contain + // "placeholder" Packages with only the ID set. + NeedImports + + // NeedDeps adds the fields requested by the LoadMode in the packages in Imports. + NeedDeps + + // NeedExportFile adds ExportFile. + NeedExportFile + + // NeedTypes adds Types, Fset, and IllTyped. + NeedTypes + + // NeedSyntax adds Syntax. + NeedSyntax + + // NeedTypesInfo adds TypesInfo. + NeedTypesInfo + + // NeedTypesSizes adds TypesSizes. + NeedTypesSizes + + // needInternalDepsErrors adds the internal deps errors field for use by gopls. + needInternalDepsErrors + + // needInternalForTest adds the internal forTest field. + // Tests must also be set on the context for this field to be populated. + needInternalForTest + + // typecheckCgo enables full support for type checking cgo. Requires Go 1.15+. + // Modifies CompiledGoFiles and Types, and has no effect on its own. + typecheckCgo + + // NeedModule adds Module. + NeedModule + + // NeedEmbedFiles adds EmbedFiles. 
+	NeedEmbedFiles
+
+	// NeedEmbedPatterns adds EmbedPatterns.
+	NeedEmbedPatterns
+)
+
+const (
+	// Deprecated: LoadFiles exists for historical compatibility
+	// and should not be used. Please directly specify the needed fields using the Need values.
+	LoadFiles = NeedName | NeedFiles | NeedCompiledGoFiles
+
+	// Deprecated: LoadImports exists for historical compatibility
+	// and should not be used. Please directly specify the needed fields using the Need values.
+	LoadImports = LoadFiles | NeedImports
+
+	// Deprecated: LoadTypes exists for historical compatibility
+	// and should not be used. Please directly specify the needed fields using the Need values.
+	LoadTypes = LoadImports | NeedTypes | NeedTypesSizes
+
+	// Deprecated: LoadSyntax exists for historical compatibility
+	// and should not be used. Please directly specify the needed fields using the Need values.
+	LoadSyntax = LoadTypes | NeedSyntax | NeedTypesInfo
+
+	// Deprecated: LoadAllSyntax exists for historical compatibility
+	// and should not be used. Please directly specify the needed fields using the Need values.
+	LoadAllSyntax = LoadSyntax | NeedDeps
+
+	// Deprecated: NeedExportsFile is a historical misspelling of NeedExportFile.
+	NeedExportsFile = NeedExportFile
+)
+
+// A Config specifies details about how packages should be loaded.
+// The zero value is a valid configuration.
+// Calls to Load do not modify this struct.
+type Config struct {
+	// Mode controls the level of information returned for each package.
+	Mode LoadMode
+
+	// Context specifies the context for the load operation.
+	// If the context is cancelled, the loader may stop early
+	// and return an ErrCancelled error.
+	// If Context is nil, the load cannot be cancelled.
+	Context context.Context
+
+	// Logf is the logger for the config.
+	// If the user provides a logger, debug logging is enabled.
+	// If the GOPACKAGESDEBUG environment variable is set to true,
+	// but the logger is nil, default to log.Printf.
+	Logf func(format string, args ...interface{})
+
+	// Dir is the directory in which to run the build system's query tool
+	// that provides information about the packages.
+	// If Dir is empty, the tool is run in the current directory.
+	Dir string
+
+	// Env is the environment to use when invoking the build system's query tool.
+	// If Env is nil, the current environment is used.
+	// As in os/exec's Cmd, only the last value in the slice for
+	// each environment key is used. To specify the setting of only
+	// a few variables, append to the current environment, as in:
+	//
+	//	opt.Env = append(os.Environ(), "GOOS=plan9", "GOARCH=386")
+	//
+	Env []string
+
+	// gocmdRunner guards go command calls from concurrency errors.
+	gocmdRunner *gocommand.Runner
+
+	// BuildFlags is a list of command-line flags to be passed through to
+	// the build system's query tool.
+	BuildFlags []string
+
+	// modFile will be used for -modfile in go command invocations.
+	modFile string
+
+	// modFlag will be used for -mod in go command invocations.
+	modFlag string
+
+	// Fset provides source position information for syntax trees and types.
+	// If Fset is nil, Load will use a new fileset, leaving the Config's own
+	// Fset field unchanged.
+	Fset *token.FileSet
+
+	// ParseFile is called to read and parse each file
+	// when preparing a package's type-checked syntax tree.
+	// It must be safe to call ParseFile simultaneously from multiple goroutines.
+	// If ParseFile is nil, the loader will use parser.ParseFile.
+	//
+	// ParseFile should parse the source from src and use filename only for
+	// recording position information.
+	//
+	// An application may supply a custom implementation of ParseFile
+	// to change the effective file contents or the behavior of the parser,
+	// or to modify the syntax tree. For example, selectively eliminating
+	// unwanted function bodies can significantly accelerate type checking.
+	ParseFile func(fset *token.FileSet, filename string, src []byte) (*ast.File, error)
+
+	// If Tests is set, the loader includes not just the packages
+	// matching a particular pattern but also any related test packages,
+	// including test-only variants of the package and the test executable.
+	//
+	// For example, when using the go command, loading "fmt" with Tests=true
+	// returns four packages, with IDs "fmt" (the standard package),
+	// "fmt [fmt.test]" (the package as compiled for the test),
+	// "fmt_test" (the test functions from source files in package fmt_test),
+	// and "fmt.test" (the test binary).
+	//
+	// In build systems with explicit names for tests,
+	// setting Tests may have no effect.
+	Tests bool
+
+	// Overlay provides a mapping of absolute file paths to file contents.
+	// If the file with the given path already exists, the parser will use the
+	// alternative file contents provided by the map.
+	//
+	// Overlays provide incomplete support for when a given file doesn't
+	// already exist on disk. See the package doc above for more details.
+	Overlay map[string][]byte
+}
+
+// driver is the type for functions that query the build system for the
+// packages named by the patterns.
+type driver func(cfg *Config, patterns ...string) (*driverResponse, error)
+
+// driverResponse contains the results for a driver query.
+type driverResponse struct {
+	// NotHandled is returned if the request can't be handled by the current
+	// driver. If an external driver returns a response with NotHandled, the
+	// rest of the driverResponse is ignored, and go/packages will fall back
+	// to the next driver. If go/packages is extended in the future to support
+	// lists of multiple drivers, go/packages will fall back to the next driver.
+	NotHandled bool
+
+	// Compiler and Arch are the arguments to pass to types.SizesFor
+	// to get a types.Sizes to use when type checking.
+	Compiler string
+	Arch     string
+
+	// Roots is the set of package IDs that make up the root packages.
+	// We have to encode this separately because when we encode a single package
+	// we cannot know if it is one of the roots as that requires knowledge of the
+	// graph it is part of.
+	Roots []string `json:",omitempty"`
+
+	// Packages is the full set of packages in the graph.
+	// The packages are not connected into a graph.
+	// The Imports fields, if populated, will be stubs that only have their ID set.
+	// Imports will be connected and then type and syntax information added in a
+	// later pass (see refine).
+	Packages []*Package
+
+	// GoVersion is the minor version number used by the driver
+	// (e.g. the go command on the PATH) when selecting .go files.
+	// Zero means unknown.
+	GoVersion int
+}
+
+// Load loads and returns the Go packages named by the given patterns.
+//
+// Config specifies loading options;
+// nil behaves the same as an empty Config.
+//
+// Load returns an error if any of the patterns was invalid
+// as defined by the underlying build system.
+// It may return an empty list of packages without an error,
+// for instance for an empty expansion of a valid wildcard.
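+//
+// A typical call looks like the following (illustrative; the pattern and
+// mode are the caller's choice):
+//
+//	cfg := &Config{Mode: NeedName | NeedFiles | NeedImports}
+//	pkgs, err := Load(cfg, "./...")
+//	if err != nil {
+//		log.Fatal(err) // the query itself failed
+//	}
+//	for _, pkg := range pkgs {
+//		fmt.Println(pkg.ID, pkg.GoFiles)
+//	}
+//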
+// Errors associated with a particular package are recorded in the
+// corresponding Package's Errors list, and do not cause Load to
+// return an error. Clients may need to handle such errors before
+// proceeding with further analysis. The PrintErrors function is
+// provided for convenient display of all errors.
+func Load(cfg *Config, patterns ...string) ([]*Package, error) {
+	l := newLoader(cfg)
+	response, err := defaultDriver(&l.Config, patterns...)
+	if err != nil {
+		return nil, err
+	}
+	l.sizes = types.SizesFor(response.Compiler, response.Arch)
+	return l.refine(response)
+}
+
+// defaultDriver is a driver that implements go/packages' fallback behavior.
+// It will first try an external driver, if one exists. If there's
+// no external driver, or the driver returns a response with NotHandled set,
+// defaultDriver will fall back to the go list driver.
+func defaultDriver(cfg *Config, patterns ...string) (*driverResponse, error) {
+	driver := findExternalDriver(cfg)
+	if driver == nil {
+		driver = goListDriver
+	}
+	response, err := driver(cfg, patterns...)
+	if err != nil {
+		return response, err
+	} else if response.NotHandled {
+		return goListDriver(cfg, patterns...)
+	}
+	return response, nil
+}
+
+// A Package describes a loaded Go package.
+type Package struct {
+	// ID is a unique identifier for a package,
+	// in a syntax provided by the underlying build system.
+	//
+	// Because the syntax varies based on the build system,
+	// clients should treat IDs as opaque and not attempt to
+	// interpret them.
+	ID string
+
+	// Name is the package name as it appears in the package source code.
+	Name string
+
+	// PkgPath is the package path as used by the go/types package.
+	PkgPath string
+
+	// Errors contains any errors encountered querying the metadata
+	// of the package, or while parsing or type-checking its files.
+	Errors []Error
+
+	// TypeErrors contains the subset of errors produced during type checking.
+	TypeErrors []types.Error
+
+	// GoFiles lists the absolute file paths of the package's Go source files.
+	// It may include files that should not be compiled, for example because
+	// they contain non-matching build tags, are documentary pseudo-files such as
+	// unsafe/unsafe.go or builtin/builtin.go, or are subject to cgo preprocessing.
+	GoFiles []string
+
+	// CompiledGoFiles lists the absolute file paths of the package's source
+	// files that are suitable for type checking.
+	// This may differ from GoFiles if files are processed before compilation.
+	CompiledGoFiles []string
+
+	// OtherFiles lists the absolute file paths of the package's non-Go source files,
+	// including assembly, C, C++, Fortran, Objective-C, SWIG, and so on.
+	OtherFiles []string
+
+	// EmbedFiles lists the absolute file paths of the package's files
+	// embedded with go:embed.
+	EmbedFiles []string
+
+	// EmbedPatterns lists the absolute file patterns of the package's
+	// files embedded with go:embed.
+	EmbedPatterns []string
+
+	// IgnoredFiles lists source files that are not part of the package
+	// using the current build configuration but that might be part of
+	// the package using other build configurations.
+	IgnoredFiles []string
+
+	// ExportFile is the absolute path to a file containing type
+	// information for the package as provided by the build system.
+	ExportFile string
+
+	// Imports maps import paths appearing in the package's Go source files
+	// to corresponding loaded Packages.
+	Imports map[string]*Package
+
+	// Types provides type information for the package.
+	// The NeedTypes LoadMode bit sets this field for packages matching the
+	// patterns; type information for dependencies may be missing or incomplete,
+	// unless NeedDeps and NeedImports are also set.
+	Types *types.Package
+
+	// Fset provides position information for Types, TypesInfo, and Syntax.
+	// It is set only when Types is set.
+	Fset *token.FileSet
+
+	// IllTyped indicates whether the package or any dependency contains errors.
+	// It is set only when Types is set.
+	IllTyped bool
+
+	// Syntax is the package's syntax trees, for the files listed in CompiledGoFiles.
+	//
+	// The NeedSyntax LoadMode bit populates this field for packages matching the patterns.
+	// If NeedDeps and NeedImports are also set, this field will also be populated
+	// for dependencies.
+	//
+	// Syntax is kept in the same order as CompiledGoFiles, with the caveat that nils are
+	// removed. If parsing returned nil, Syntax may be shorter than CompiledGoFiles.
+	Syntax []*ast.File
+
+	// TypesInfo provides type information about the package's syntax trees.
+	// It is set only when Syntax is set.
+	TypesInfo *types.Info
+
+	// TypesSizes provides the effective size function for types in TypesInfo.
+	TypesSizes types.Sizes
+
+	// forTest is the package under test, if any.
+	forTest string
+
+	// depsErrors is the DepsErrors field from the go list response, if any.
+	depsErrors []*packagesinternal.PackageError
+
+	// Module is the module information for the package if it exists.
+	Module *Module
+}
+
+// Module provides module information for a package.
+type Module struct {
+	Path      string       // module path
+	Version   string       // module version
+	Replace   *Module      // replaced by this module
+	Time      *time.Time   // time version was created
+	Main      bool         // is this the main module?
+	Indirect  bool         // is this module only an indirect dependency of main module?
+	Dir       string       // directory holding files for this module, if any
+	GoMod     string       // path to go.mod file used when loading this module, if any
+	GoVersion string       // go version used in module
+	Error     *ModuleError // error loading module
+}
+
+// ModuleError holds errors loading a module.
+type ModuleError struct {
+	Err string // the error itself
+}
+
+func init() {
+	packagesinternal.GetForTest = func(p interface{}) string {
+		return p.(*Package).forTest
+	}
+	packagesinternal.GetDepsErrors = func(p interface{}) []*packagesinternal.PackageError {
+		return p.(*Package).depsErrors
+	}
+	packagesinternal.GetGoCmdRunner = func(config interface{}) *gocommand.Runner {
+		return config.(*Config).gocmdRunner
+	}
+	packagesinternal.SetGoCmdRunner = func(config interface{}, runner *gocommand.Runner) {
+		config.(*Config).gocmdRunner = runner
+	}
+	packagesinternal.SetModFile = func(config interface{}, value string) {
+		config.(*Config).modFile = value
+	}
+	packagesinternal.SetModFlag = func(config interface{}, value string) {
+		config.(*Config).modFlag = value
+	}
+	packagesinternal.TypecheckCgo = int(typecheckCgo)
+	packagesinternal.DepsErrors = int(needInternalDepsErrors)
+	packagesinternal.ForTest = int(needInternalForTest)
+}
+
+// An Error describes a problem with a package's metadata, syntax, or types.
+type Error struct {
+	Pos  string // "file:line:col" or "file:line" or "" or "-"
+	Msg  string
+	Kind ErrorKind
+}
+
+// ErrorKind describes the source of the error, allowing the user to
+// differentiate between errors generated by the driver, the parser, or the
+// type-checker.
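+//
+// For instance (illustrative values), a TypeError reported by the type
+// checker carries a position such as "a.go:2:10" in Pos, so the Error
+// method renders it as:
+//
+//	a.go:2:10: undeclared name: x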
+type ErrorKind int
+
+const (
+	UnknownError ErrorKind = iota
+	ListError
+	ParseError
+	TypeError
+)
+
+func (err Error) Error() string {
+	pos := err.Pos
+	if pos == "" {
+		pos = "-" // like token.Position{}.String()
+	}
+	return pos + ": " + err.Msg
+}
+
+// flatPackage is the JSON form of Package.
+// It drops all the type and syntax fields, and transforms the Imports map
+// into a map of import path to package ID.
+//
+// TODO(adonovan): identify this struct with Package, effectively
+// publishing the JSON protocol.
+type flatPackage struct {
+	ID              string
+	Name            string            `json:",omitempty"`
+	PkgPath         string            `json:",omitempty"`
+	Errors          []Error           `json:",omitempty"`
+	GoFiles         []string          `json:",omitempty"`
+	CompiledGoFiles []string          `json:",omitempty"`
+	OtherFiles      []string          `json:",omitempty"`
+	EmbedFiles      []string          `json:",omitempty"`
+	EmbedPatterns   []string          `json:",omitempty"`
+	IgnoredFiles    []string          `json:",omitempty"`
+	ExportFile      string            `json:",omitempty"`
+	Imports         map[string]string `json:",omitempty"`
+}
+
+// MarshalJSON returns the Package in its JSON form.
+// For the most part, the structure fields are written out unmodified, and
+// the type and syntax fields are skipped.
+// The imports are written out as just a map of path to package id.
+// The errors are written using a custom type that tries to preserve the
+// structure of error types we know about.
+//
+// This method exists to enable support for additional build systems. It is
+// not intended for use by clients of the API and we may change the format.
+func (p *Package) MarshalJSON() ([]byte, error) {
+	flat := &flatPackage{
+		ID:              p.ID,
+		Name:            p.Name,
+		PkgPath:         p.PkgPath,
+		Errors:          p.Errors,
+		GoFiles:         p.GoFiles,
+		CompiledGoFiles: p.CompiledGoFiles,
+		OtherFiles:      p.OtherFiles,
+		EmbedFiles:      p.EmbedFiles,
+		EmbedPatterns:   p.EmbedPatterns,
+		IgnoredFiles:    p.IgnoredFiles,
+		ExportFile:      p.ExportFile,
+	}
+	if len(p.Imports) > 0 {
+		flat.Imports = make(map[string]string, len(p.Imports))
+		for path, ipkg := range p.Imports {
+			flat.Imports[path] = ipkg.ID
+		}
+	}
+	return json.Marshal(flat)
+}
+
+// UnmarshalJSON reads in a Package from its JSON format.
+// See MarshalJSON for details about the format accepted.
+func (p *Package) UnmarshalJSON(b []byte) error {
+	flat := &flatPackage{}
+	if err := json.Unmarshal(b, &flat); err != nil {
+		return err
+	}
+	*p = Package{
+		ID:              flat.ID,
+		Name:            flat.Name,
+		PkgPath:         flat.PkgPath,
+		Errors:          flat.Errors,
+		GoFiles:         flat.GoFiles,
+		CompiledGoFiles: flat.CompiledGoFiles,
+		OtherFiles:      flat.OtherFiles,
+		EmbedFiles:      flat.EmbedFiles,
+		EmbedPatterns:   flat.EmbedPatterns,
+		ExportFile:      flat.ExportFile,
+	}
+	if len(flat.Imports) > 0 {
+		p.Imports = make(map[string]*Package, len(flat.Imports))
+		for path, id := range flat.Imports {
+			p.Imports[path] = &Package{ID: id}
+		}
+	}
+	return nil
+}
+
+func (p *Package) String() string { return p.ID }
+
+// loaderPackage augments Package with state used during the loading phase.
+type loaderPackage struct {
+	*Package
+	importErrors map[string]error // maps each bad import to its error
+	loadOnce     sync.Once
+	color        uint8 // for cycle detection
+	needsrc      bool  // load from source (Mode >= LoadTypes)
+	needtypes    bool  // type information is either requested or depended on
+	initial      bool  // package was matched by a pattern
+	goVersion    int   // minor version number of go command on PATH
+}
+
+// loader holds the working state of a single call to load.
+type loader struct { + pkgs map[string]*loaderPackage + Config + sizes types.Sizes + parseCache map[string]*parseValue + parseCacheMu sync.Mutex + exportMu sync.Mutex // enforces mutual exclusion of exportdata operations + + // Config.Mode contains the implied mode (see impliedLoadMode). + // Implied mode contains all the fields we need the data for. + // In requestedMode there are the actually requested fields. + // We'll zero them out before returning packages to the user. + // This makes it easier for us to get the conditions where + // we need certain modes right. + requestedMode LoadMode +} + +type parseValue struct { + f *ast.File + err error + ready chan struct{} +} + +func newLoader(cfg *Config) *loader { + ld := &loader{ + parseCache: map[string]*parseValue{}, + } + if cfg != nil { + ld.Config = *cfg + // If the user has provided a logger, use it. + ld.Config.Logf = cfg.Logf + } + if ld.Config.Logf == nil { + // If the GOPACKAGESDEBUG environment variable is set to true, + // but the user has not provided a logger, default to log.Printf. + if debug { + ld.Config.Logf = log.Printf + } else { + ld.Config.Logf = func(format string, args ...interface{}) {} + } + } + if ld.Config.Mode == 0 { + ld.Config.Mode = NeedName | NeedFiles | NeedCompiledGoFiles // Preserve zero behavior of Mode for backwards compatibility. + } + if ld.Config.Env == nil { + ld.Config.Env = os.Environ() + } + if ld.Config.gocmdRunner == nil { + ld.Config.gocmdRunner = &gocommand.Runner{} + } + if ld.Context == nil { + ld.Context = context.Background() + } + if ld.Dir == "" { + if dir, err := os.Getwd(); err == nil { + ld.Dir = dir + } + } + + // Save the actually requested fields. We'll zero them out before returning packages to the user. + ld.requestedMode = ld.Mode + ld.Mode = impliedLoadMode(ld.Mode) + + if ld.Mode&NeedTypes != 0 || ld.Mode&NeedSyntax != 0 { + if ld.Fset == nil { + ld.Fset = token.NewFileSet() + } + + // ParseFile is required even in LoadTypes mode + // because we load source if export data is missing. + if ld.ParseFile == nil { + ld.ParseFile = func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + const mode = parser.AllErrors | parser.ParseComments + return parser.ParseFile(fset, filename, src, mode) + } + } + } + + return ld +} + +// refine connects the supplied packages into a graph and then adds type +// and syntax information as requested by the LoadMode. +func (ld *loader) refine(response *driverResponse) ([]*Package, error) { + roots := response.Roots + rootMap := make(map[string]int, len(roots)) + for i, root := range roots { + rootMap[root] = i + } + ld.pkgs = make(map[string]*loaderPackage) + // first pass, fixup and build the map and roots + var initial = make([]*loaderPackage, len(roots)) + for _, pkg := range response.Packages { + rootIndex := -1 + if i, found := rootMap[pkg.ID]; found { + rootIndex = i + } + + // Overlays can invalidate export data. + // TODO(matloob): make this check fine-grained based on dependencies on overlaid files + exportDataInvalid := len(ld.Overlay) > 0 || pkg.ExportFile == "" && pkg.PkgPath != "unsafe" + // This package needs type information if the caller requested types and the package is + // either a root, or it's a non-root and the user requested dependencies ... 
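+		// (Worked example, illustrative: with Mode = NeedTypes alone, only
+		// root packages get needtypes; adding NeedDeps extends it to every
+		// package in the graph, and missing or invalid export data forces
+		// needsrc below as well.)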
+		needtypes := (ld.Mode&(NeedTypes|NeedTypesInfo) != 0 && (rootIndex >= 0 || ld.Mode&NeedDeps != 0))
+		// This package needs source if the call requested source (or types info, which implies source)
+		// and the package is either a root, or it's a non-root and the user requested dependencies...
+		needsrc := ((ld.Mode&(NeedSyntax|NeedTypesInfo) != 0 && (rootIndex >= 0 || ld.Mode&NeedDeps != 0)) ||
+			// ... or if we need types and the export data is invalid. We fall back to (incompletely)
+			// typechecking packages from source if they fail to compile.
+			(ld.Mode&(NeedTypes|NeedTypesInfo) != 0 && exportDataInvalid)) && pkg.PkgPath != "unsafe"
+		lpkg := &loaderPackage{
+			Package:   pkg,
+			needtypes: needtypes,
+			needsrc:   needsrc,
+			goVersion: response.GoVersion,
+		}
+		ld.pkgs[lpkg.ID] = lpkg
+		if rootIndex >= 0 {
+			initial[rootIndex] = lpkg
+			lpkg.initial = true
+		}
+	}
+	for i, root := range roots {
+		if initial[i] == nil {
+			return nil, fmt.Errorf("root package %v is missing", root)
+		}
+	}
+
+	// Materialize the import graph.
+
+	const (
+		white = 0 // new
+		grey  = 1 // in progress
+		black = 2 // complete
+	)
+
+	// visit traverses the import graph, depth-first,
+	// and materializes the graph as Package.Imports.
+	//
+	// Valid imports are saved in the Package.Imports map.
+	// Invalid imports (cycles and missing nodes) are saved in the importErrors map.
+	// Thus, even in the presence of both kinds of errors, the Import graph remains a DAG.
+	//
+	// visit returns whether the package needs src or has a transitive
+	// dependency on a package that does. These are the only packages
+	// for which we load source code.
+	var stack []*loaderPackage
+	var visit func(lpkg *loaderPackage) bool
+	var srcPkgs []*loaderPackage
+	visit = func(lpkg *loaderPackage) bool {
+		switch lpkg.color {
+		case black:
+			return lpkg.needsrc
+		case grey:
+			panic("internal error: grey node")
+		}
+		lpkg.color = grey
+		stack = append(stack, lpkg) // push
+		stubs := lpkg.Imports       // the structure form has only stubs with the ID in the Imports
+		// If NeedImports isn't set, the imports fields will all be zeroed out.
+		if ld.Mode&NeedImports != 0 {
+			lpkg.Imports = make(map[string]*Package, len(stubs))
+			for importPath, ipkg := range stubs {
+				var importErr error
+				imp := ld.pkgs[ipkg.ID]
+				if imp == nil {
+					// (includes package "C" when DisableCgo)
+					importErr = fmt.Errorf("missing package: %q", ipkg.ID)
+				} else if imp.color == grey {
+					importErr = fmt.Errorf("import cycle: %s", stack)
+				}
+				if importErr != nil {
+					if lpkg.importErrors == nil {
+						lpkg.importErrors = make(map[string]error)
+					}
+					lpkg.importErrors[importPath] = importErr
+					continue
+				}
+
+				if visit(imp) {
+					lpkg.needsrc = true
+				}
+				lpkg.Imports[importPath] = imp.Package
+			}
+		}
+		if lpkg.needsrc {
+			srcPkgs = append(srcPkgs, lpkg)
+		}
+		if ld.Mode&NeedTypesSizes != 0 {
+			lpkg.TypesSizes = ld.sizes
+		}
+		stack = stack[:len(stack)-1] // pop
+		lpkg.color = black
+
+		return lpkg.needsrc
+	}
+
+	if ld.Mode&NeedImports == 0 {
+		// We do this to drop the stub import packages that we are not even going to try to resolve.
+		for _, lpkg := range initial {
+			lpkg.Imports = nil
+		}
+	} else {
+		// For each initial package, create its import DAG.
+		for _, lpkg := range initial {
+			visit(lpkg)
+		}
+	}
+	if ld.Mode&NeedImports != 0 && ld.Mode&NeedTypes != 0 {
+		for _, lpkg := range srcPkgs {
+			// Complete type information is required for the
+			// immediate dependencies of each source package.
+ for _, ipkg := range lpkg.Imports { + imp := ld.pkgs[ipkg.ID] + imp.needtypes = true + } + } + } + // Load type data and syntax if needed, starting at + // the initial packages (roots of the import DAG). + if ld.Mode&NeedTypes != 0 || ld.Mode&NeedSyntax != 0 { + var wg sync.WaitGroup + for _, lpkg := range initial { + wg.Add(1) + go func(lpkg *loaderPackage) { + ld.loadRecursive(lpkg) + wg.Done() + }(lpkg) + } + wg.Wait() + } + + result := make([]*Package, len(initial)) + for i, lpkg := range initial { + result[i] = lpkg.Package + } + for i := range ld.pkgs { + // Clear all unrequested fields, + // to catch programs that use more than they request. + if ld.requestedMode&NeedName == 0 { + ld.pkgs[i].Name = "" + ld.pkgs[i].PkgPath = "" + } + if ld.requestedMode&NeedFiles == 0 { + ld.pkgs[i].GoFiles = nil + ld.pkgs[i].OtherFiles = nil + ld.pkgs[i].IgnoredFiles = nil + } + if ld.requestedMode&NeedEmbedFiles == 0 { + ld.pkgs[i].EmbedFiles = nil + } + if ld.requestedMode&NeedEmbedPatterns == 0 { + ld.pkgs[i].EmbedPatterns = nil + } + if ld.requestedMode&NeedCompiledGoFiles == 0 { + ld.pkgs[i].CompiledGoFiles = nil + } + if ld.requestedMode&NeedImports == 0 { + ld.pkgs[i].Imports = nil + } + if ld.requestedMode&NeedExportFile == 0 { + ld.pkgs[i].ExportFile = "" + } + if ld.requestedMode&NeedTypes == 0 { + ld.pkgs[i].Types = nil + ld.pkgs[i].Fset = nil + ld.pkgs[i].IllTyped = false + } + if ld.requestedMode&NeedSyntax == 0 { + ld.pkgs[i].Syntax = nil + } + if ld.requestedMode&NeedTypesInfo == 0 { + ld.pkgs[i].TypesInfo = nil + } + if ld.requestedMode&NeedTypesSizes == 0 { + ld.pkgs[i].TypesSizes = nil + } + if ld.requestedMode&NeedModule == 0 { + ld.pkgs[i].Module = nil + } + } + + return result, nil +} + +// loadRecursive loads the specified package and its dependencies, +// recursively, in parallel, in topological order. +// It is atomic and idempotent. +// Precondition: ld.Mode&NeedTypes. +func (ld *loader) loadRecursive(lpkg *loaderPackage) { + lpkg.loadOnce.Do(func() { + // Load the direct dependencies, in parallel. + var wg sync.WaitGroup + for _, ipkg := range lpkg.Imports { + imp := ld.pkgs[ipkg.ID] + wg.Add(1) + go func(imp *loaderPackage) { + ld.loadRecursive(imp) + wg.Done() + }(imp) + } + wg.Wait() + ld.loadPackage(lpkg) + }) +} + +// loadPackage loads the specified package. +// It must be called only once per Package, +// after immediate dependencies are loaded. +// Precondition: ld.Mode & NeedTypes. +func (ld *loader) loadPackage(lpkg *loaderPackage) { + if lpkg.PkgPath == "unsafe" { + // Fill in the blanks to avoid surprises. + lpkg.Types = types.Unsafe + lpkg.Fset = ld.Fset + lpkg.Syntax = []*ast.File{} + lpkg.TypesInfo = new(types.Info) + lpkg.TypesSizes = ld.sizes + return + } + + // Call NewPackage directly with explicit name. + // This avoids skew between golist and go/types when the files' + // package declarations are inconsistent. + lpkg.Types = types.NewPackage(lpkg.PkgPath, lpkg.Name) + lpkg.Fset = ld.Fset + + // Subtle: we populate all Types fields with an empty Package + // before loading export data so that export data processing + // never has to create a types.Package for an indirect dependency, + // which would then require that such created packages be explicitly + // inserted back into the Import graph as a final step after export data loading. + // (Hence this return is after the Types assignment.) + // The Diamond test exercises this case. 
+ if !lpkg.needtypes && !lpkg.needsrc { + return + } + if !lpkg.needsrc { + if err := ld.loadFromExportData(lpkg); err != nil { + lpkg.Errors = append(lpkg.Errors, Error{ + Pos: "-", + Msg: err.Error(), + Kind: UnknownError, // e.g. can't find/open/parse export data + }) + } + return // not a source package, don't get syntax trees + } + + appendError := func(err error) { + // Convert various error types into the one true Error. + var errs []Error + switch err := err.(type) { + case Error: + // from driver + errs = append(errs, err) + + case *os.PathError: + // from parser + errs = append(errs, Error{ + Pos: err.Path + ":1", + Msg: err.Err.Error(), + Kind: ParseError, + }) + + case scanner.ErrorList: + // from parser + for _, err := range err { + errs = append(errs, Error{ + Pos: err.Pos.String(), + Msg: err.Msg, + Kind: ParseError, + }) + } + + case types.Error: + // from type checker + lpkg.TypeErrors = append(lpkg.TypeErrors, err) + errs = append(errs, Error{ + Pos: err.Fset.Position(err.Pos).String(), + Msg: err.Msg, + Kind: TypeError, + }) + + default: + // unexpected impoverished error from parser? + errs = append(errs, Error{ + Pos: "-", + Msg: err.Error(), + Kind: UnknownError, + }) + + // If you see this error message, please file a bug. + log.Printf("internal error: error %q (%T) without position", err, err) + } + + lpkg.Errors = append(lpkg.Errors, errs...) + } + + // If the go command on the PATH is newer than the runtime, + // then the go/{scanner,ast,parser,types} packages from the + // standard library may be unable to process the files + // selected by go list. + // + // There is currently no way to downgrade the effective + // version of the go command (see issue 52078), so we proceed + // with the newer go command but, in case of parse or type + // errors, we emit an additional diagnostic. + // + // See: + // - golang.org/issue/52078 (flag to set release tags) + // - golang.org/issue/50825 (gopls legacy version support) + // - golang.org/issue/55883 (go/packages confusing error) + // + // Should we assert a hard minimum of (currently) go1.16 here? + var runtimeVersion int + if _, err := fmt.Sscanf(runtime.Version(), "go1.%d", &runtimeVersion); err == nil && runtimeVersion < lpkg.goVersion { + defer func() { + if len(lpkg.Errors) > 0 { + appendError(Error{ + Pos: "-", + Msg: fmt.Sprintf("This application uses version go1.%d of the source-processing packages but runs version go1.%d of 'go list'. It may fail to process source files that rely on newer language features. If so, rebuild the application using a newer version of Go.", runtimeVersion, lpkg.goVersion), + Kind: UnknownError, + }) + } + }() + } + + if ld.Config.Mode&NeedTypes != 0 && len(lpkg.CompiledGoFiles) == 0 && lpkg.ExportFile != "" { + // The config requested loading sources and types, but sources are missing. + // Add an error to the package and fall back to loading from export data. 
+		appendError(Error{"-", fmt.Sprintf("sources missing for package %s", lpkg.ID), ParseError})
+		_ = ld.loadFromExportData(lpkg) // ignore any secondary errors
+
+		return // can't get syntax trees for this package
+	}
+
+	files, errs := ld.parseFiles(lpkg.CompiledGoFiles)
+	for _, err := range errs {
+		appendError(err)
+	}
+
+	lpkg.Syntax = files
+	if ld.Config.Mode&NeedTypes == 0 {
+		return
+	}
+
+	lpkg.TypesInfo = &types.Info{
+		Types:      make(map[ast.Expr]types.TypeAndValue),
+		Defs:       make(map[*ast.Ident]types.Object),
+		Uses:       make(map[*ast.Ident]types.Object),
+		Implicits:  make(map[ast.Node]types.Object),
+		Scopes:     make(map[ast.Node]*types.Scope),
+		Selections: make(map[*ast.SelectorExpr]*types.Selection),
+	}
+	typeparams.InitInstanceInfo(lpkg.TypesInfo)
+	lpkg.TypesSizes = ld.sizes
+
+	importer := importerFunc(func(path string) (*types.Package, error) {
+		if path == "unsafe" {
+			return types.Unsafe, nil
+		}
+
+		// The imports map is keyed by import path.
+		ipkg := lpkg.Imports[path]
+		if ipkg == nil {
+			if err := lpkg.importErrors[path]; err != nil {
+				return nil, err
+			}
+			// There was skew between the metadata and the
+			// import declarations, likely due to an edit
+			// race, or because the ParseFile feature was
+			// used to supply alternative file contents.
+			return nil, fmt.Errorf("no metadata for %s", path)
+		}
+
+		if ipkg.Types != nil && ipkg.Types.Complete() {
+			return ipkg.Types, nil
+		}
+		log.Fatalf("internal error: package %q without types was imported from %q", path, lpkg)
+		panic("unreachable")
+	})
+
+	// type-check
+	tc := &types.Config{
+		Importer: importer,
+
+		// Type-check bodies of functions only in initial packages.
+		// Example: for import graph A->B->C and initial packages {A,C},
+		// we can ignore function bodies in B.
+		IgnoreFuncBodies: ld.Mode&NeedDeps == 0 && !lpkg.initial,
+
+		Error: appendError,
+		Sizes: ld.sizes,
+	}
+	if lpkg.Module != nil && lpkg.Module.GoVersion != "" {
+		typesinternal.SetGoVersion(tc, "go"+lpkg.Module.GoVersion)
+	}
+	if (ld.Mode & typecheckCgo) != 0 {
+		if !typesinternal.SetUsesCgo(tc) {
+			appendError(Error{
+				Msg:  "typecheckCgo requires Go 1.15+",
+				Kind: ListError,
+			})
+			return
+		}
+	}
+	types.NewChecker(tc, ld.Fset, lpkg.Types, lpkg.TypesInfo).Files(lpkg.Syntax)
+
+	lpkg.importErrors = nil // no longer needed
+
+	// If !Cgo, the type-checker uses FakeImportC mode, so
+	// it doesn't invoke the importer for import "C",
+	// nor report an error for the import,
+	// or for any undefined C.f reference.
+	// We must detect this explicitly and correctly
+	// mark the package as IllTyped (by reporting an error).
+	// TODO(adonovan): if these errors are annoying,
+	// we could just set IllTyped quietly.
+	if tc.FakeImportC {
+	outer:
+		for _, f := range lpkg.Syntax {
+			for _, imp := range f.Imports {
+				if imp.Path.Value == `"C"` {
+					err := types.Error{Fset: ld.Fset, Pos: imp.Pos(), Msg: `import "C" ignored`}
+					appendError(err)
+					break outer
+				}
+			}
+		}
+	}
+
+	// Record accumulated errors.
+	illTyped := len(lpkg.Errors) > 0
+	if !illTyped {
+		for _, imp := range lpkg.Imports {
+			if imp.IllTyped {
+				illTyped = true
+				break
+			}
+		}
+	}
+	lpkg.IllTyped = illTyped
+}
+
+// An importerFunc is an implementation of the single-method
+// types.Importer interface based on a function value.
+type importerFunc func(path string) (*types.Package, error)
+
+func (f importerFunc) Import(path string) (*types.Package, error) { return f(path) }
+
+// We use a counting semaphore to limit
+// the number of parallel I/O calls per process.
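+//
+// The idiom, as used by parseFile below, is the standard buffered-channel
+// semaphore (sketch):
+//
+//	ioLimit <- true // acquire one of the 20 slots
+//	src, err := ioutil.ReadFile(filename)
+//	<-ioLimit // release the slot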
+var ioLimit = make(chan bool, 20) + +func (ld *loader) parseFile(filename string) (*ast.File, error) { + ld.parseCacheMu.Lock() + v, ok := ld.parseCache[filename] + if ok { + // cache hit + ld.parseCacheMu.Unlock() + <-v.ready + } else { + // cache miss + v = &parseValue{ready: make(chan struct{})} + ld.parseCache[filename] = v + ld.parseCacheMu.Unlock() + + var src []byte + for f, contents := range ld.Config.Overlay { + if sameFile(f, filename) { + src = contents + } + } + var err error + if src == nil { + ioLimit <- true // wait + src, err = ioutil.ReadFile(filename) + <-ioLimit // signal + } + if err != nil { + v.err = err + } else { + v.f, v.err = ld.ParseFile(ld.Fset, filename, src) + } + + close(v.ready) + } + return v.f, v.err +} + +// parseFiles reads and parses the Go source files and returns the ASTs +// of the ones that could be at least partially parsed, along with a +// list of I/O and parse errors encountered. +// +// Because files are scanned in parallel, the token.Pos +// positions of the resulting ast.Files are not ordered. +func (ld *loader) parseFiles(filenames []string) ([]*ast.File, []error) { + var wg sync.WaitGroup + n := len(filenames) + parsed := make([]*ast.File, n) + errors := make([]error, n) + for i, file := range filenames { + if ld.Config.Context.Err() != nil { + parsed[i] = nil + errors[i] = ld.Config.Context.Err() + continue + } + wg.Add(1) + go func(i int, filename string) { + parsed[i], errors[i] = ld.parseFile(filename) + wg.Done() + }(i, file) + } + wg.Wait() + + // Eliminate nils, preserving order. + var o int + for _, f := range parsed { + if f != nil { + parsed[o] = f + o++ + } + } + parsed = parsed[:o] + + o = 0 + for _, err := range errors { + if err != nil { + errors[o] = err + o++ + } + } + errors = errors[:o] + + return parsed, errors +} + +// sameFile returns true if x and y have the same basename and denote +// the same file. +func sameFile(x, y string) bool { + if x == y { + // It could be the case that y doesn't exist. + // For instance, it may be an overlay file that + // hasn't been written to disk. To handle that case + // let x == y through. (We added the exact absolute path + // string to the CompiledGoFiles list, so the unwritten + // overlay case implies x==y.) + return true + } + if strings.EqualFold(filepath.Base(x), filepath.Base(y)) { // (optimisation) + if xi, err := os.Stat(x); err == nil { + if yi, err := os.Stat(y); err == nil { + return os.SameFile(xi, yi) + } + } + } + return false +} + +// loadFromExportData ensures that type information is present for the specified +// package, loading it from an export data file on the first request. +// On success it sets lpkg.Types to a new Package. +func (ld *loader) loadFromExportData(lpkg *loaderPackage) error { + if lpkg.PkgPath == "" { + log.Fatalf("internal error: Package %s has no PkgPath", lpkg) + } + + // Because gcexportdata.Read has the potential to create or + // modify the types.Package for each node in the transitive + // closure of dependencies of lpkg, all exportdata operations + // must be sequential. (Finer-grained locking would require + // changes to the gcexportdata API.) + // + // The exportMu lock guards the lpkg.Types field and the + // types.Package it points to, for each loaderPackage in the graph. + // + // Not all accesses to Package.Pkg need to be protected by exportMu: + // graph ordering ensures that direct dependencies of source + // packages are fully loaded before the importer reads their Pkg field. 
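The parseFile implementation above is a compact example of two useful concurrency idioms: a per-key record with a ready channel, so that concurrent requests for the same file block on a single parse instead of duplicating work, and the ioLimit counting semaphore, which bounds the number of parallel reads. The following is a minimal, self-contained sketch of the combined pattern; it is an editorial illustration only, all names in it are hypothetical, and the string concatenation stands in for the real read-and-parse work:

```go
package main

import (
	"fmt"
	"sync"
)

// entry mirrors the parseValue idea: ready is closed once value is set,
// so later requesters block until the first computation finishes.
type entry struct {
	ready chan struct{}
	value string
}

var (
	mu    sync.Mutex
	cache = map[string]*entry{}
	sem   = make(chan bool, 20) // counting semaphore: at most 20 concurrent loads
)

// load returns the cached value for key, computing it exactly once.
func load(key string) string {
	mu.Lock()
	e, ok := cache[key]
	if ok {
		// Cache hit: wait for the winner to finish.
		mu.Unlock()
		<-e.ready
		return e.value
	}
	// Cache miss: publish the entry first, then compute.
	e = &entry{ready: make(chan struct{})}
	cache[key] = e
	mu.Unlock()

	sem <- true                    // acquire (like ioLimit <- true)
	e.value = "contents of " + key // stand-in for the read-and-parse work
	<-sem                          // release

	close(e.ready) // wake all waiters; happens after e.value is written
	return e.value
}

func main() {
	var wg sync.WaitGroup
	for i := 0; i < 4; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			fmt.Println(load("a.go")) // the work runs once; all four print the same value
		}()
	}
	wg.Wait()
}
```

Closing the ready channel after the value is written gives waiters a happens-before guarantee, so no extra locking is needed on the value itself. The loadFromExportData body below then takes exportMu exactly as the preceding comment describes.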
+ ld.exportMu.Lock() + defer ld.exportMu.Unlock() + + if tpkg := lpkg.Types; tpkg != nil && tpkg.Complete() { + return nil // cache hit + } + + lpkg.IllTyped = true // fail safe + + if lpkg.ExportFile == "" { + // Errors while building export data will have been printed to stderr. + return fmt.Errorf("no export data file") + } + f, err := os.Open(lpkg.ExportFile) + if err != nil { + return err + } + defer f.Close() + + // Read gc export data. + // + // We don't currently support gccgo export data because all + // underlying workspaces use the gc toolchain. (Even build + // systems that support gccgo don't use it for workspace + // queries.) + r, err := gcexportdata.NewReader(f) + if err != nil { + return fmt.Errorf("reading %s: %v", lpkg.ExportFile, err) + } + + // Build the view. + // + // The gcexportdata machinery has no concept of package ID. + // It identifies packages by their PkgPath, which although not + // globally unique is unique within the scope of one invocation + // of the linker, type-checker, or gcexportdata. + // + // So, we must build a PkgPath-keyed view of the global + // (conceptually ID-keyed) cache of packages and pass it to + // gcexportdata. The view must contain every existing + // package that might possibly be mentioned by the + // current package---its transitive closure. + // + // In loadPackage, we unconditionally create a types.Package for + // each dependency so that export data loading does not + // create new ones. + // + // TODO(adonovan): it would be simpler and more efficient + // if the export data machinery invoked a callback to + // get-or-create a package instead of a map. + // + view := make(map[string]*types.Package) // view seen by gcexportdata + seen := make(map[*loaderPackage]bool) // all visited packages + var visit func(pkgs map[string]*Package) + visit = func(pkgs map[string]*Package) { + for _, p := range pkgs { + lpkg := ld.pkgs[p.ID] + if !seen[lpkg] { + seen[lpkg] = true + view[lpkg.PkgPath] = lpkg.Types + visit(lpkg.Imports) + } + } + } + visit(lpkg.Imports) + + viewLen := len(view) + 1 // adding the self package + // Parse the export data. + // (May modify incomplete packages in view but not create new ones.) + tpkg, err := gcexportdata.Read(r, ld.Fset, view, lpkg.PkgPath) + if err != nil { + return fmt.Errorf("reading %s: %v", lpkg.ExportFile, err) + } + if _, ok := view["go.shape"]; ok { + // Account for the pseudopackage "go.shape" that gets + // created by generic code. + viewLen++ + } + if viewLen != len(view) { + log.Panicf("golang.org/x/tools/go/packages: unexpected new packages during load of %s", lpkg.PkgPath) + } + + lpkg.Types = tpkg + lpkg.IllTyped = false + return nil +} + +// impliedLoadMode returns loadMode with its dependencies. +func impliedLoadMode(loadMode LoadMode) LoadMode { + if loadMode&(NeedDeps|NeedTypes|NeedTypesInfo) != 0 { + // All these things require knowing the import graph. + loadMode |= NeedImports + } + + return loadMode +} + +func usesExportData(cfg *Config) bool { + return cfg.Mode&NeedExportFile != 0 || cfg.Mode&NeedTypes != 0 && cfg.Mode&NeedDeps == 0 +} + +var _ interface{} = io.Discard // assert build toolchain is go1.16 or later diff --git a/vendor/golang.org/x/tools/internal/gcimporter/gcimporter.go b/vendor/golang.org/x/tools/internal/gcimporter/gcimporter.go new file mode 100644 index 0000000000..b1223713b9 --- /dev/null +++ b/vendor/golang.org/x/tools/internal/gcimporter/gcimporter.go @@ -0,0 +1,274 @@ +// Copyright 2011 The Go Authors. All rights reserved. 
+// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file is a reduced copy of $GOROOT/src/go/internal/gcimporter/gcimporter.go. + +// Package gcimporter provides various functions for reading +// gc-generated object files that can be used to implement the +// Importer interface defined by the Go 1.5 standard library package. +// +// The encoding is deterministic: if the encoder is applied twice to +// the same types.Package data structure, both encodings are equal. +// This property may be important to avoid spurious changes in +// applications such as build systems. +// +// However, the encoder is not necessarily idempotent. Importing an +// exported package may yield a types.Package that, while it +// represents the same set of Go types as the original, may differ in +// the details of its internal representation. Because of these +// differences, re-encoding the imported package may yield a +// different, but equally valid, encoding of the package. +package gcimporter // import "golang.org/x/tools/internal/gcimporter" + +import ( + "bufio" + "bytes" + "fmt" + "go/build" + "go/token" + "go/types" + "io" + "io/ioutil" + "os" + "os/exec" + "path/filepath" + "strings" + "sync" +) + +const ( + // Enable debug during development: it adds some additional checks, and + // prevents errors from being recovered. + debug = false + + // If trace is set, debugging output is printed to std out. + trace = false +) + +var exportMap sync.Map // package dir → func() (string, bool) + +// lookupGorootExport returns the location of the export data +// (normally found in the build cache, but located in GOROOT/pkg +// in prior Go releases) for the package located in pkgDir. +// +// (We use the package's directory instead of its import path +// mainly to simplify handling of the packages in src/vendor +// and cmd/vendor.) +func lookupGorootExport(pkgDir string) (string, bool) { + f, ok := exportMap.Load(pkgDir) + if !ok { + var ( + listOnce sync.Once + exportPath string + ) + f, _ = exportMap.LoadOrStore(pkgDir, func() (string, bool) { + listOnce.Do(func() { + cmd := exec.Command("go", "list", "-export", "-f", "{{.Export}}", pkgDir) + cmd.Dir = build.Default.GOROOT + var output []byte + output, err := cmd.Output() + if err != nil { + return + } + + exports := strings.Split(string(bytes.TrimSpace(output)), "\n") + if len(exports) != 1 { + return + } + + exportPath = exports[0] + }) + + return exportPath, exportPath != "" + }) + } + + return f.(func() (string, bool))() +} + +var pkgExts = [...]string{".a", ".o"} + +// FindPkg returns the filename and unique package id for an import +// path based on package information provided by build.Import (using +// the build.Default build.Context). A relative srcDir is interpreted +// relative to the current working directory. +// If no file was found, an empty filename is returned. +func FindPkg(path, srcDir string) (filename, id string) { + if path == "" { + return + } + + var noext string + switch { + default: + // "x" -> "$GOPATH/pkg/$GOOS_$GOARCH/x.ext", "x" + // Don't require the source files to be present. 
+ if abs, err := filepath.Abs(srcDir); err == nil { // see issue 14282 + srcDir = abs + } + bp, _ := build.Import(path, srcDir, build.FindOnly|build.AllowBinary) + if bp.PkgObj == "" { + var ok bool + if bp.Goroot && bp.Dir != "" { + filename, ok = lookupGorootExport(bp.Dir) + } + if !ok { + id = path // make sure we have an id to print in error message + return + } + } else { + noext = strings.TrimSuffix(bp.PkgObj, ".a") + id = bp.ImportPath + } + + case build.IsLocalImport(path): + // "./x" -> "/this/directory/x.ext", "/this/directory/x" + noext = filepath.Join(srcDir, path) + id = noext + + case filepath.IsAbs(path): + // for completeness only - go/build.Import + // does not support absolute imports + // "/x" -> "/x.ext", "/x" + noext = path + id = path + } + + if false { // for debugging + if path != id { + fmt.Printf("%s -> %s\n", path, id) + } + } + + if filename != "" { + if f, err := os.Stat(filename); err == nil && !f.IsDir() { + return + } + } + + // try extensions + for _, ext := range pkgExts { + filename = noext + ext + if f, err := os.Stat(filename); err == nil && !f.IsDir() { + return + } + } + + filename = "" // not found + return +} + +// Import imports a gc-generated package given its import path and srcDir, adds +// the corresponding package object to the packages map, and returns the object. +// The packages map must contain all packages already imported. +func Import(packages map[string]*types.Package, path, srcDir string, lookup func(path string) (io.ReadCloser, error)) (pkg *types.Package, err error) { + var rc io.ReadCloser + var filename, id string + if lookup != nil { + // With custom lookup specified, assume that caller has + // converted path to a canonical import path for use in the map. + if path == "unsafe" { + return types.Unsafe, nil + } + id = path + + // No need to re-import if the package was imported completely before. + if pkg = packages[id]; pkg != nil && pkg.Complete() { + return + } + f, err := lookup(path) + if err != nil { + return nil, err + } + rc = f + } else { + filename, id = FindPkg(path, srcDir) + if filename == "" { + if path == "unsafe" { + return types.Unsafe, nil + } + return nil, fmt.Errorf("can't find import: %q", id) + } + + // no need to re-import if the package was imported completely before + if pkg = packages[id]; pkg != nil && pkg.Complete() { + return + } + + // open file + f, err := os.Open(filename) + if err != nil { + return nil, err + } + defer func() { + if err != nil { + // add file name to error + err = fmt.Errorf("%s: %v", filename, err) + } + }() + rc = f + } + defer rc.Close() + + var hdr string + var size int64 + buf := bufio.NewReader(rc) + if hdr, size, err = FindExportData(buf); err != nil { + return + } + + switch hdr { + case "$$B\n": + var data []byte + data, err = ioutil.ReadAll(buf) + if err != nil { + break + } + + // TODO(gri): allow clients of go/importer to provide a FileSet. + // Or, define a new standard go/types/gcexportdata package. + fset := token.NewFileSet() + + // Select appropriate importer. 
+ if len(data) > 0 { + switch data[0] { + case 'v', 'c', 'd': // binary, till go1.10 + return nil, fmt.Errorf("binary (%c) import format is no longer supported", data[0]) + + case 'i': // indexed, till go1.19 + _, pkg, err := IImportData(fset, packages, data[1:], id) + return pkg, err + + case 'u': // unified, from go1.20 + _, pkg, err := UImportData(fset, packages, data[1:size], id) + return pkg, err + + default: + l := len(data) + if l > 10 { + l = 10 + } + return nil, fmt.Errorf("unexpected export data with prefix %q for path %s", string(data[:l]), id) + } + } + + default: + err = fmt.Errorf("unknown export data header: %q", hdr) + } + + return +} + +func deref(typ types.Type) types.Type { + if p, _ := typ.(*types.Pointer); p != nil { + return p.Elem() + } + return typ +} + +type byPath []*types.Package + +func (a byPath) Len() int { return len(a) } +func (a byPath) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a byPath) Less(i, j int) bool { return a[i].Path() < a[j].Path() } diff --git a/vendor/gonum.org/v1/gonum/AUTHORS b/vendor/gonum.org/v1/gonum/AUTHORS new file mode 100644 index 0000000000..7d49714ab2 --- /dev/null +++ b/vendor/gonum.org/v1/gonum/AUTHORS @@ -0,0 +1,125 @@ +# This is the official list of Gonum authors for copyright purposes. +# This file is distinct from the CONTRIBUTORS files. +# See the latter for an explanation. + +# Names should be added to this file as +# Name or Organization +# The email address is not required for organizations. + +# Please keep the list sorted. + +Alexander Egurnov +Andrei Blinnikov +antichris +Bill Gray +Bill Noon +Brendan Tracey +Brent Pedersen +Chad Kunde +Chan Kwan Yin +Chih-Wei Chang +Chong-Yeol Nah +Chris Tessum +Christophe Meessen +Christopher Waldon +Clayton Northey +Dan Kortschak +Daniel Fireman +Dario Heinisch +David Kleiven +David Samborski +Davor Kapsa +DeepMind Technologies +Delaney Gillilan +Dezmond Goff +Dong-hee Na +Dustin Spicuzza +Egon Elbre +Ekaterina Efimova +Ethan Burns +Evert Lammerts +Evgeny Savinov +Fabian Wickborn +Facundo Gaich +Fazlul Shahriar +Francesc Campoy +Google Inc +Gustaf Johansson +Hossein Zolfi +Iakov Davydov +Igor Mikushkin +Iskander Sharipov +Jalem Raj Rohit +James Bell +James Bowman +James Holmes <32bitkid@gmail.com> +Janne Snabb +Jeremy Atkinson +Jinesi Yelizati +Jonas Kahler +Jonas Schulze +Jonathan J Lawlor +Jonathan Reiter +Jonathan Schroeder +Joost van Amersfoort +Joseph Watson +Josh Wilson +Julien Roland +Kai Trukenmüller +Kent English +Kevin C. 
Zimmerman +Kirill Motkov +Konstantin Shaposhnikov +Leonid Kneller +Lyron Winderbaum +Marco Leogrande +Mark Canning +Mark Skilbeck +Martin Diz +Matthew Connelly +Matthieu Di Mercurio +Max Halford +Maxim Sergeev +Microsoft Corporation +MinJae Kwon +Nathan Edwards +Nick Potts +Nils Wogatzky +Olivier Wulveryck +Or Rikon +Patricio Whittingslow +Patrick DeVivo +Pontus Melke +Renee French +Rishi Desai +Robin Eklind +Roger Welin +Rondall Jones +Sam Zaydel +Samuel Kelemen +Saran Ahluwalia +Scott Holden +Scott Kiesel +Sebastien Binet +Shawn Smith +Sintela Ltd +source{d} +Spencer Lyon +Steve McCoy +Taesu Pyo +Takeshi Yoneda +Tamir Hyman +The University of Adelaide +The University of Minnesota +The University of Washington +Thomas Berg +Tobin Harding +Valentin Deleplace +Vincent Thiery +Vladimír Chalupecký +Will Tekulve +Yasuhiro Matsumoto +Yevgeniy Vahlis +Yucheng Zhu +Yunomi +Zoe Juozapaitis diff --git a/vendor/gonum.org/v1/gonum/CONTRIBUTORS b/vendor/gonum.org/v1/gonum/CONTRIBUTORS new file mode 100644 index 0000000000..b8bef3e337 --- /dev/null +++ b/vendor/gonum.org/v1/gonum/CONTRIBUTORS @@ -0,0 +1,128 @@ +# This is the official list of people who can contribute +# (and typically have contributed) code to the Gonum +# project. +# +# The AUTHORS file lists the copyright holders; this file +# lists people. For example, Google employees would be listed here +# but not in AUTHORS, because Google would hold the copyright. +# +# When adding J Random Contributor's name to this file, +# either J's name or J's organization's name should be +# added to the AUTHORS file. +# +# Names should be added to this file like so: +# Name +# +# Please keep the list sorted. + +Alexander Egurnov +Andrei Blinnikov +Andrew Brampton +antichris +Bill Gray +Bill Noon +Brendan Tracey +Brent Pedersen +Chad Kunde +Chan Kwan Yin +Chih-Wei Chang +Chong-Yeol Nah +Chris Tessum +Christophe Meessen +Christopher Waldon +Clayton Northey +Dan Kortschak +Dan Lorenc +Daniel Fireman +Dario Heinisch +David Kleiven +David Samborski +Davor Kapsa +Delaney Gillilan +Dezmond Goff +Dong-hee Na +Dustin Spicuzza +Egon Elbre +Ekaterina Efimova +Ethan Burns +Evert Lammerts +Evgeny Savinov +Fabian Wickborn +Facundo Gaich +Fazlul Shahriar +Francesc Campoy +Gustaf Johansson +Hossein Zolfi +Iakov Davydov +Igor Mikushkin +Iskander Sharipov +Jalem Raj Rohit +James Bell +James Bowman +James Holmes <32bitkid@gmail.com> +Janne Snabb +Jeremy Atkinson +Jinesi Yelizati +Jon Richards +Jonas Kahler +Jonas Schulze +Jonathan J Lawlor +Jonathan Reiter +Jonathan Schroeder +Joost van Amersfoort +Joseph Watson +Josh Wilson +Julien Roland +Kai Trukenmüller +Kent English +Kevin C. 
Zimmerman +Kirill Motkov +Konstantin Shaposhnikov +Leonid Kneller +Lyron Winderbaum +Marco Leogrande +Mark Canning +Mark Skilbeck +Martin Diz +Matthew Connelly +Matthieu Di Mercurio +Max Halford +Maxim Sergeev +MinJae Kwon +Nathan Edwards +Nick Potts +Nils Wogatzky +Olivier Wulveryck +Or Rikon +Patricio Whittingslow +Patrick DeVivo +Pontus Melke +Renee French +Rishi Desai +Robin Eklind +Roger Welin +Roman Werpachowski +Rondall Jones +Sam Zaydel +Samuel Kelemen +Saran Ahluwalia +Scott Holden +Scott Kiesel +Sebastien Binet +Shawn Smith +Spencer Lyon +Steve McCoy +Taesu Pyo +Takeshi Yoneda +Tamir Hyman +Thomas Berg +Tobin Harding +Valentin Deleplace +Vincent Thiery +Vladimír Chalupecký +Will Tekulve +Yasuhiro Matsumoto +Yevgeniy Vahlis +Yucheng Zhu +Yunomi +Zoe Juozapaitis diff --git a/vendor/gonum.org/v1/gonum/lapack/gonum/dbdsqr.go b/vendor/gonum.org/v1/gonum/lapack/gonum/dbdsqr.go new file mode 100644 index 0000000000..e9c055b32a --- /dev/null +++ b/vendor/gonum.org/v1/gonum/lapack/gonum/dbdsqr.go @@ -0,0 +1,508 @@ +// Copyright ©2015 The Gonum Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gonum + +import ( + "math" + + "gonum.org/v1/gonum/blas" + "gonum.org/v1/gonum/blas/blas64" + "gonum.org/v1/gonum/lapack" +) + +// Dbdsqr performs a singular value decomposition of a real n×n bidiagonal matrix. +// +// The SVD of the bidiagonal matrix B is +// +// B = Q * S * Pᵀ +// +// where S is a diagonal matrix of singular values, Q is an orthogonal matrix of +// left singular vectors, and P is an orthogonal matrix of right singular vectors. +// +// Q and P are only computed if requested. If left singular vectors are requested, +// this routine returns U * Q instead of Q, and if right singular vectors are +// requested Pᵀ * VT is returned instead of Pᵀ. +// +// Frequently Dbdsqr is used in conjunction with Dgebrd which reduces a general +// matrix A into bidiagonal form. In this case, the SVD of A is +// +// A = (U * Q) * S * (Pᵀ * VT) +// +// This routine may also compute Qᵀ * C. +// +// d and e contain the elements of the bidiagonal matrix b. d must have length at +// least n, and e must have length at least n-1. Dbdsqr will panic if there is +// insufficient length. On exit, D contains the singular values of B in decreasing +// order. +// +// VT is a matrix of size n×ncvt whose elements are stored in vt. The elements +// of vt are modified to contain Pᵀ * VT on exit. VT is not used if ncvt == 0. +// +// U is a matrix of size nru×n whose elements are stored in u. The elements +// of u are modified to contain U * Q on exit. U is not used if nru == 0. +// +// C is a matrix of size n×ncc whose elements are stored in c. The elements +// of c are modified to contain Qᵀ * C on exit. C is not used if ncc == 0. +// +// work contains temporary storage and must have length at least 4*(n-1). Dbdsqr +// will panic if there is insufficient working memory. +// +// Dbdsqr returns whether the decomposition was successful. +// +// Dbdsqr is an internal routine. It is exported for testing purposes. 
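As a quick aside before the function itself: the calling convention documented above can be exercised with a minimal, hedged sketch (editorial illustration, not part of the vendored file). It computes only the singular values of a small upper bidiagonal matrix, so ncvt = nru = ncc = 0 and the vector arguments may be nil; the matrix entries are arbitrary example data:

```go
package main

import (
	"fmt"

	"gonum.org/v1/gonum/blas"
	"gonum.org/v1/gonum/lapack/gonum"
)

func main() {
	var impl gonum.Implementation
	// Upper bidiagonal B (example data):
	//   [2 1 0]
	//   [0 3 1]
	//   [0 0 4]
	n := 3
	d := []float64{2, 3, 4} // diagonal, overwritten with the singular values
	e := []float64{1, 1}    // superdiagonal, zeroed on successful convergence
	// The doc comment requires len(work) >= 4*(n-1); 4*n is allocated here
	// to be generous, since the values-only path delegates to Dlasq1.
	work := make([]float64, 4*n)
	// No singular vectors: ncvt = nru = ncc = 0, so vt, u, c may be nil
	// and their leading dimensions only need to be 1.
	ok := impl.Dbdsqr(blas.Upper, n, 0, 0, 0, d, e, nil, 1, nil, 1, nil, 1, work)
	fmt.Println(ok, d) // d now holds the singular values in decreasing order
}
```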
+func (impl Implementation) Dbdsqr(uplo blas.Uplo, n, ncvt, nru, ncc int, d, e, vt []float64, ldvt int, u []float64, ldu int, c []float64, ldc int, work []float64) (ok bool) { + switch { + case uplo != blas.Upper && uplo != blas.Lower: + panic(badUplo) + case n < 0: + panic(nLT0) + case ncvt < 0: + panic(ncvtLT0) + case nru < 0: + panic(nruLT0) + case ncc < 0: + panic(nccLT0) + case ldvt < max(1, ncvt): + panic(badLdVT) + case (ldu < max(1, n) && nru > 0) || (ldu < 1 && nru == 0): + panic(badLdU) + case ldc < max(1, ncc): + panic(badLdC) + } + + // Quick return if possible. + if n == 0 { + return true + } + + if len(vt) < (n-1)*ldvt+ncvt && ncvt != 0 { + panic(shortVT) + } + if len(u) < (nru-1)*ldu+n && nru != 0 { + panic(shortU) + } + if len(c) < (n-1)*ldc+ncc && ncc != 0 { + panic(shortC) + } + if len(d) < n { + panic(shortD) + } + if len(e) < n-1 { + panic(shortE) + } + if len(work) < 4*(n-1) { + panic(shortWork) + } + + var info int + bi := blas64.Implementation() + const maxIter = 6 + + if n != 1 { + // If the singular vectors do not need to be computed, use qd algorithm. + if !(ncvt > 0 || nru > 0 || ncc > 0) { + info = impl.Dlasq1(n, d, e, work) + // If info is 2 dqds didn't finish, and so try to. + if info != 2 { + return info == 0 + } + } + nm1 := n - 1 + nm12 := nm1 + nm1 + nm13 := nm12 + nm1 + idir := 0 + + eps := dlamchE + unfl := dlamchS + lower := uplo == blas.Lower + var cs, sn, r float64 + if lower { + for i := 0; i < n-1; i++ { + cs, sn, r = impl.Dlartg(d[i], e[i]) + d[i] = r + e[i] = sn * d[i+1] + d[i+1] *= cs + work[i] = cs + work[nm1+i] = sn + } + if nru > 0 { + impl.Dlasr(blas.Right, lapack.Variable, lapack.Forward, nru, n, work, work[n-1:], u, ldu) + } + if ncc > 0 { + impl.Dlasr(blas.Left, lapack.Variable, lapack.Forward, n, ncc, work, work[n-1:], c, ldc) + } + } + // Compute singular values to a relative accuracy of tol. If tol is negative + // the values will be computed to an absolute accuracy of math.Abs(tol) * norm(b) + tolmul := math.Max(10, math.Min(100, math.Pow(eps, -1.0/8))) + tol := tolmul * eps + var smax float64 + for i := 0; i < n; i++ { + smax = math.Max(smax, math.Abs(d[i])) + } + for i := 0; i < n-1; i++ { + smax = math.Max(smax, math.Abs(e[i])) + } + + var sminl float64 + var thresh float64 + if tol >= 0 { + sminoa := math.Abs(d[0]) + if sminoa != 0 { + mu := sminoa + for i := 1; i < n; i++ { + mu = math.Abs(d[i]) * (mu / (mu + math.Abs(e[i-1]))) + sminoa = math.Min(sminoa, mu) + if sminoa == 0 { + break + } + } + } + sminoa = sminoa / math.Sqrt(float64(n)) + thresh = math.Max(tol*sminoa, float64(maxIter*n*n)*unfl) + } else { + thresh = math.Max(math.Abs(tol)*smax, float64(maxIter*n*n)*unfl) + } + // Prepare for the main iteration loop for the singular values. + maxIt := maxIter * n * n + iter := 0 + oldl2 := -1 + oldm := -1 + // m points to the last element of unconverged part of matrix. + m := n + + Outer: + for m > 1 { + if iter > maxIt { + info = 0 + for i := 0; i < n-1; i++ { + if e[i] != 0 { + info++ + } + } + return info == 0 + } + // Find diagonal block of matrix to work on. 
+			if tol < 0 && math.Abs(d[m-1]) <= thresh {
+				d[m-1] = 0
+			}
+			smax = math.Abs(d[m-1])
+			smin := smax
+			var l2 int
+			var broke bool
+			for l3 := 0; l3 < m-1; l3++ {
+				l2 = m - l3 - 2
+				abss := math.Abs(d[l2])
+				abse := math.Abs(e[l2])
+				if tol < 0 && abss <= thresh {
+					d[l2] = 0
+				}
+				if abse <= thresh {
+					broke = true
+					break
+				}
+				smin = math.Min(smin, abss)
+				smax = math.Max(math.Max(smax, abss), abse)
+			}
+			if broke {
+				e[l2] = 0
+				if l2 == m-2 {
+					// Convergence of bottom singular value, return to top.
+					m--
+					continue
+				}
+				l2++
+			} else {
+				l2 = 0
+			}
+			// e[l2] through e[m-2] are nonzero, e[l2-1] is zero.
+			if l2 == m-2 {
+				// Handle 2×2 block separately.
+				var sinr, cosr, sinl, cosl float64
+				d[m-1], d[m-2], sinr, cosr, sinl, cosl = impl.Dlasv2(d[m-2], e[m-2], d[m-1])
+				e[m-2] = 0
+				if ncvt > 0 {
+					bi.Drot(ncvt, vt[(m-2)*ldvt:], 1, vt[(m-1)*ldvt:], 1, cosr, sinr)
+				}
+				if nru > 0 {
+					bi.Drot(nru, u[m-2:], ldu, u[m-1:], ldu, cosl, sinl)
+				}
+				if ncc > 0 {
+					bi.Drot(ncc, c[(m-2)*ldc:], 1, c[(m-1)*ldc:], 1, cosl, sinl)
+				}
+				m -= 2
+				continue
+			}
+			// If working on a new submatrix, choose shift direction from larger end
+			// diagonal element toward smaller.
+			if l2 > oldm-1 || m-1 < oldl2 {
+				if math.Abs(d[l2]) >= math.Abs(d[m-1]) {
+					idir = 1
+				} else {
+					idir = 2
+				}
+			}
+			// Apply convergence tests.
+			// TODO(btracey): There is a lot of similar looking code here. See
+			// if there is a better way to de-duplicate.
+			if idir == 1 {
+				// Run convergence test in forward direction.
+				// First apply standard test to bottom of matrix.
+				if math.Abs(e[m-2]) <= math.Abs(tol)*math.Abs(d[m-1]) || (tol < 0 && math.Abs(e[m-2]) <= thresh) {
+					e[m-2] = 0
+					continue
+				}
+				if tol >= 0 {
+					// If relative accuracy desired, apply convergence criterion forward.
+					mu := math.Abs(d[l2])
+					sminl = mu
+					for l3 := l2; l3 < m-1; l3++ {
+						if math.Abs(e[l3]) <= tol*mu {
+							e[l3] = 0
+							continue Outer
+						}
+						mu = math.Abs(d[l3+1]) * (mu / (mu + math.Abs(e[l3])))
+						sminl = math.Min(sminl, mu)
+					}
+				}
+			} else {
+				// Run convergence test in backward direction.
+				// First apply standard test to top of matrix.
+				if math.Abs(e[l2]) <= math.Abs(tol)*math.Abs(d[l2]) || (tol < 0 && math.Abs(e[l2]) <= thresh) {
+					e[l2] = 0
+					continue
+				}
+				if tol >= 0 {
+					// If relative accuracy desired, apply convergence criterion backward.
+					mu := math.Abs(d[m-1])
+					sminl = mu
+					for l3 := m - 2; l3 >= l2; l3-- {
+						if math.Abs(e[l3]) <= tol*mu {
+							e[l3] = 0
+							continue Outer
+						}
+						mu = math.Abs(d[l3]) * (mu / (mu + math.Abs(e[l3])))
+						sminl = math.Min(sminl, mu)
+					}
+				}
+			}
+			oldl2 = l2
+			oldm = m
+			// Compute shift. First, test if shifting would ruin relative accuracy,
+			// and if so set the shift to zero.
+			var shift float64
+			if tol >= 0 && float64(n)*tol*(sminl/smax) <= math.Max(eps, (1.0/100)*tol) {
+				shift = 0
+			} else {
+				var sl2 float64
+				if idir == 1 {
+					sl2 = math.Abs(d[l2])
+					shift, _ = impl.Dlas2(d[m-2], e[m-2], d[m-1])
+				} else {
+					sl2 = math.Abs(d[m-1])
+					shift, _ = impl.Dlas2(d[l2], e[l2], d[l2+1])
+				}
+				// Test if shift is negligible.
+				if sl2 > 0 {
+					if (shift/sl2)*(shift/sl2) < eps {
+						shift = 0
+					}
+				}
+			}
+			iter += m - l2 - 1
+			// If no shift, do simplified QR iteration.
+			if shift == 0 {
+				if idir == 1 {
+					cs := 1.0
+					oldcs := 1.0
+					var sn, r, oldsn float64
+					for i := l2; i < m-1; i++ {
+						cs, sn, r = impl.Dlartg(d[i]*cs, e[i])
+						if i > l2 {
+							e[i-1] = oldsn * r
+						}
+						oldcs, oldsn, d[i] = impl.Dlartg(oldcs*r, d[i+1]*sn)
+						work[i-l2] = cs
+						work[i-l2+nm1] = sn
+						work[i-l2+nm12] = oldcs
+						work[i-l2+nm13] = oldsn
+					}
+					h := d[m-1] * cs
+					d[m-1] = h * oldcs
+					e[m-2] = h * oldsn
+					if ncvt > 0 {
+						impl.Dlasr(blas.Left, lapack.Variable, lapack.Forward, m-l2, ncvt, work, work[n-1:], vt[l2*ldvt:], ldvt)
+					}
+					if nru > 0 {
+						impl.Dlasr(blas.Right, lapack.Variable, lapack.Forward, nru, m-l2, work[nm12:], work[nm13:], u[l2:], ldu)
+					}
+					if ncc > 0 {
+						impl.Dlasr(blas.Left, lapack.Variable, lapack.Forward, m-l2, ncc, work[nm12:], work[nm13:], c[l2*ldc:], ldc)
+					}
+					if math.Abs(e[m-2]) <= thresh {
+						e[m-2] = 0
+					}
+				} else {
+					cs := 1.0
+					oldcs := 1.0
+					var sn, r, oldsn float64
+					for i := m - 1; i >= l2+1; i-- {
+						cs, sn, r = impl.Dlartg(d[i]*cs, e[i-1])
+						if i < m-1 {
+							e[i] = oldsn * r
+						}
+						oldcs, oldsn, d[i] = impl.Dlartg(oldcs*r, d[i-1]*sn)
+						work[i-l2-1] = cs
+						work[i-l2+nm1-1] = -sn
+						work[i-l2+nm12-1] = oldcs
+						work[i-l2+nm13-1] = -oldsn
+					}
+					h := d[l2] * cs
+					d[l2] = h * oldcs
+					e[l2] = h * oldsn
+					if ncvt > 0 {
+						impl.Dlasr(blas.Left, lapack.Variable, lapack.Backward, m-l2, ncvt, work[nm12:], work[nm13:], vt[l2*ldvt:], ldvt)
+					}
+					if nru > 0 {
+						impl.Dlasr(blas.Right, lapack.Variable, lapack.Backward, nru, m-l2, work, work[n-1:], u[l2:], ldu)
+					}
+					if ncc > 0 {
+						impl.Dlasr(blas.Left, lapack.Variable, lapack.Backward, m-l2, ncc, work, work[n-1:], c[l2*ldc:], ldc)
+					}
+					if math.Abs(e[l2]) <= thresh {
+						e[l2] = 0
+					}
+				}
+			} else {
+				// Use nonzero shift.
+				if idir == 1 {
+					// Chase bulge from top to bottom. Save cosines and sines for
+					// later singular vector updates.
+					f := (math.Abs(d[l2]) - shift) * (math.Copysign(1, d[l2]) + shift/d[l2])
+					g := e[l2]
+					var cosl, sinl float64
+					for i := l2; i < m-1; i++ {
+						cosr, sinr, r := impl.Dlartg(f, g)
+						if i > l2 {
+							e[i-1] = r
+						}
+						f = cosr*d[i] + sinr*e[i]
+						e[i] = cosr*e[i] - sinr*d[i]
+						g = sinr * d[i+1]
+						d[i+1] *= cosr
+						cosl, sinl, r = impl.Dlartg(f, g)
+						d[i] = r
+						f = cosl*e[i] + sinl*d[i+1]
+						d[i+1] = cosl*d[i+1] - sinl*e[i]
+						if i < m-2 {
+							g = sinl * e[i+1]
+							e[i+1] = cosl * e[i+1]
+						}
+						work[i-l2] = cosr
+						work[i-l2+nm1] = sinr
+						work[i-l2+nm12] = cosl
+						work[i-l2+nm13] = sinl
+					}
+					e[m-2] = f
+					if ncvt > 0 {
+						impl.Dlasr(blas.Left, lapack.Variable, lapack.Forward, m-l2, ncvt, work, work[n-1:], vt[l2*ldvt:], ldvt)
+					}
+					if nru > 0 {
+						impl.Dlasr(blas.Right, lapack.Variable, lapack.Forward, nru, m-l2, work[nm12:], work[nm13:], u[l2:], ldu)
+					}
+					if ncc > 0 {
+						impl.Dlasr(blas.Left, lapack.Variable, lapack.Forward, m-l2, ncc, work[nm12:], work[nm13:], c[l2*ldc:], ldc)
+					}
+					if math.Abs(e[m-2]) <= thresh {
+						e[m-2] = 0
+					}
+				} else {
+					// Chase bulge from bottom to top. Save cosines and sines for
+					// later singular vector updates.
+					f := (math.Abs(d[m-1]) - shift) * (math.Copysign(1, d[m-1]) + shift/d[m-1])
+					g := e[m-2]
+					for i := m - 1; i > l2; i-- {
+						cosr, sinr, r := impl.Dlartg(f, g)
+						if i < m-1 {
+							e[i] = r
+						}
+						f = cosr*d[i] + sinr*e[i-1]
+						e[i-1] = cosr*e[i-1] - sinr*d[i]
+						g = sinr * d[i-1]
+						d[i-1] *= cosr
+						cosl, sinl, r := impl.Dlartg(f, g)
+						d[i] = r
+						f = cosl*e[i-1] + sinl*d[i-1]
+						d[i-1] = cosl*d[i-1] - sinl*e[i-1]
+						if i > l2+1 {
+							g = sinl * e[i-2]
+							e[i-2] *= cosl
+						}
+						work[i-l2-1] = cosr
+						work[i-l2+nm1-1] = -sinr
+						work[i-l2+nm12-1] = cosl
+						work[i-l2+nm13-1] = -sinl
+					}
+					e[l2] = f
+					if math.Abs(e[l2]) <= thresh {
+						e[l2] = 0
+					}
+					if ncvt > 0 {
+						impl.Dlasr(blas.Left, lapack.Variable, lapack.Backward, m-l2, ncvt, work[nm12:], work[nm13:], vt[l2*ldvt:], ldvt)
+					}
+					if nru > 0 {
+						impl.Dlasr(blas.Right, lapack.Variable, lapack.Backward, nru, m-l2, work, work[n-1:], u[l2:], ldu)
+					}
+					if ncc > 0 {
+						impl.Dlasr(blas.Left, lapack.Variable, lapack.Backward, m-l2, ncc, work, work[n-1:], c[l2*ldc:], ldc)
+					}
+				}
+			}
+		}
+	}
+
+	// All singular values converged, make them positive.
+	for i := 0; i < n; i++ {
+		if d[i] < 0 {
+			d[i] *= -1
+			if ncvt > 0 {
+				bi.Dscal(ncvt, -1, vt[i*ldvt:], 1)
+			}
+		}
+	}
+
+	// Sort the singular values in decreasing order.
+	for i := 0; i < n-1; i++ {
+		isub := 0
+		smin := d[0]
+		for j := 1; j < n-i; j++ {
+			if d[j] <= smin {
+				isub = j
+				smin = d[j]
+			}
+		}
+		if isub != n-i-1 {
+			// Swap singular values and vectors.
+			d[isub] = d[n-i-1]
+			d[n-i-1] = smin
+			if ncvt > 0 {
+				bi.Dswap(ncvt, vt[isub*ldvt:], 1, vt[(n-i-1)*ldvt:], 1)
+			}
+			if nru > 0 {
+				bi.Dswap(nru, u[isub:], ldu, u[n-i-1:], ldu)
+			}
+			if ncc > 0 {
+				bi.Dswap(ncc, c[isub*ldc:], 1, c[(n-i-1)*ldc:], 1)
+			}
+		}
+	}
+	info = 0
+	for i := 0; i < n-1; i++ {
+		if e[i] != 0 {
+			info++
+		}
+	}
+	return info == 0
+}
diff --git a/vendor/gonum.org/v1/gonum/lapack/gonum/dgecon.go b/vendor/gonum.org/v1/gonum/lapack/gonum/dgecon.go
new file mode 100644
index 0000000000..1d1ca586bb
--- /dev/null
+++ b/vendor/gonum.org/v1/gonum/lapack/gonum/dgecon.go
@@ -0,0 +1,92 @@
+// Copyright ©2015 The Gonum Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gonum
+
+import (
+	"math"
+
+	"gonum.org/v1/gonum/blas"
+	"gonum.org/v1/gonum/blas/blas64"
+	"gonum.org/v1/gonum/lapack"
+)
+
+// Dgecon estimates the reciprocal of the condition number of the n×n matrix A
+// given the LU decomposition of the matrix. The condition number computed may
+// be based on the 1-norm or the ∞-norm.
+//
+// The slice a contains the result of the LU decomposition of A as computed by Dgetrf.
+//
+// anorm is the corresponding 1-norm or ∞-norm of the original matrix A.
+//
+// work is a temporary data slice of length at least 4*n and Dgecon will panic otherwise.
+//
+// iwork is a temporary data slice of length at least n and Dgecon will panic otherwise.
+func (impl Implementation) Dgecon(norm lapack.MatrixNorm, n int, a []float64, lda int, anorm float64, work []float64, iwork []int) float64 {
+	switch {
+	case norm != lapack.MaxColumnSum && norm != lapack.MaxRowSum:
+		panic(badNorm)
+	case n < 0:
+		panic(nLT0)
+	case lda < max(1, n):
+		panic(badLdA)
+	}
+
+	// Quick return if possible.
+	if n == 0 {
+		return 1
+	}
+
+	switch {
+	case len(a) < (n-1)*lda+n:
+		panic(shortA)
+	case len(work) < 4*n:
+		panic(shortWork)
+	case len(iwork) < n:
+		panic(shortIWork)
+	}
+
+	// Quick return if possible.
+ if anorm == 0 { + return 0 + } + + bi := blas64.Implementation() + var rcond, ainvnm float64 + var kase int + var normin bool + isave := new([3]int) + onenrm := norm == lapack.MaxColumnSum + smlnum := dlamchS + kase1 := 2 + if onenrm { + kase1 = 1 + } + for { + ainvnm, kase = impl.Dlacn2(n, work[n:], work, iwork, ainvnm, kase, isave) + if kase == 0 { + if ainvnm != 0 { + rcond = (1 / ainvnm) / anorm + } + return rcond + } + var sl, su float64 + if kase == kase1 { + sl = impl.Dlatrs(blas.Lower, blas.NoTrans, blas.Unit, normin, n, a, lda, work, work[2*n:]) + su = impl.Dlatrs(blas.Upper, blas.NoTrans, blas.NonUnit, normin, n, a, lda, work, work[3*n:]) + } else { + su = impl.Dlatrs(blas.Upper, blas.Trans, blas.NonUnit, normin, n, a, lda, work, work[3*n:]) + sl = impl.Dlatrs(blas.Lower, blas.Trans, blas.Unit, normin, n, a, lda, work, work[2*n:]) + } + scale := sl * su + normin = true + if scale != 1 { + ix := bi.Idamax(n, work, 1) + if scale == 0 || scale < math.Abs(work[ix])*smlnum { + return rcond + } + impl.Drscl(n, scale, work, 1) + } + } +} diff --git a/vendor/gonum.org/v1/gonum/lapack/gonum/dhseqr.go b/vendor/gonum.org/v1/gonum/lapack/gonum/dhseqr.go new file mode 100644 index 0000000000..80fe19bb0b --- /dev/null +++ b/vendor/gonum.org/v1/gonum/lapack/gonum/dhseqr.go @@ -0,0 +1,272 @@ +// Copyright ©2016 The Gonum Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gonum + +import ( + "math" + + "gonum.org/v1/gonum/blas" + "gonum.org/v1/gonum/lapack" +) + +// Dhseqr computes the eigenvalues of an n×n Hessenberg matrix H and, +// optionally, the matrices T and Z from the Schur decomposition +// +// H = Z T Zᵀ, +// +// where T is an n×n upper quasi-triangular matrix (the Schur form), and Z is +// the n×n orthogonal matrix of Schur vectors. +// +// Optionally Z may be postmultiplied into an input orthogonal matrix Q so that +// this routine can give the Schur factorization of a matrix A which has been +// reduced to the Hessenberg form H by the orthogonal matrix Q: +// +// A = Q H Qᵀ = (QZ) T (QZ)ᵀ. +// +// If job == lapack.EigenvaluesOnly, only the eigenvalues will be computed. +// If job == lapack.EigenvaluesAndSchur, the eigenvalues and the Schur form T will +// be computed. +// For other values of job Dhseqr will panic. +// +// If compz == lapack.SchurNone, no Schur vectors will be computed and Z will not be +// referenced. +// If compz == lapack.SchurHess, on return Z will contain the matrix of Schur +// vectors of H. +// If compz == lapack.SchurOrig, on entry z is assumed to contain the orthogonal +// matrix Q that is the identity except for the submatrix +// Q[ilo:ihi+1,ilo:ihi+1]. On return z will be updated to the product Q*Z. +// +// ilo and ihi determine the block of H on which Dhseqr operates. It is assumed +// that H is already upper triangular in rows and columns [0:ilo] and [ihi+1:n], +// although it will be only checked that the block is isolated, that is, +// +// ilo == 0 or H[ilo,ilo-1] == 0, +// ihi == n-1 or H[ihi+1,ihi] == 0, +// +// and Dhseqr will panic otherwise. ilo and ihi are typically set by a previous +// call to Dgebal, otherwise they should be set to 0 and n-1, respectively. It +// must hold that +// +// 0 <= ilo <= ihi < n if n > 0, +// ilo == 0 and ihi == -1 if n == 0. +// +// wr and wi must have length n. +// +// work must have length at least lwork and lwork must be at least max(1,n) +// otherwise Dhseqr will panic. 
The minimum lwork delivers very good and +// sometimes optimal performance, although lwork as large as 11*n may be +// required. On return, work[0] will contain the optimal value of lwork. +// +// If lwork is -1, instead of performing Dhseqr, the function only estimates the +// optimal workspace size and stores it into work[0]. Neither h nor z are +// accessed. +// +// unconverged indicates whether Dhseqr computed all the eigenvalues. +// +// If unconverged == 0, all the eigenvalues have been computed and their real +// and imaginary parts will be stored on return in wr and wi, respectively. If +// two eigenvalues are computed as a complex conjugate pair, they are stored in +// consecutive elements of wr and wi, say the i-th and (i+1)th, with wi[i] > 0 +// and wi[i+1] < 0. +// +// If unconverged == 0 and job == lapack.EigenvaluesAndSchur, on return H will +// contain the upper quasi-triangular matrix T from the Schur decomposition (the +// Schur form). 2×2 diagonal blocks (corresponding to complex conjugate pairs of +// eigenvalues) will be returned in standard form, with +// +// H[i,i] == H[i+1,i+1], +// +// and +// +// H[i+1,i]*H[i,i+1] < 0. +// +// The eigenvalues will be stored in wr and wi in the same order as on the +// diagonal of the Schur form returned in H, with +// +// wr[i] = H[i,i], +// +// and, if H[i:i+2,i:i+2] is a 2×2 diagonal block, +// +// wi[i] = sqrt(-H[i+1,i]*H[i,i+1]), +// wi[i+1] = -wi[i]. +// +// If unconverged == 0 and job == lapack.EigenvaluesOnly, the contents of h +// on return is unspecified. +// +// If unconverged > 0, some eigenvalues have not converged, and the blocks +// [0:ilo] and [unconverged:n] of wr and wi will contain those eigenvalues which +// have been successfully computed. Failures are rare. +// +// If unconverged > 0 and job == lapack.EigenvaluesOnly, on return the +// remaining unconverged eigenvalues are the eigenvalues of the upper Hessenberg +// matrix H[ilo:unconverged,ilo:unconverged]. +// +// If unconverged > 0 and job == lapack.EigenvaluesAndSchur, then on +// return +// +// (initial H) U = U (final H), (*) +// +// where U is an orthogonal matrix. The final H is upper Hessenberg and +// H[unconverged:ihi+1,unconverged:ihi+1] is upper quasi-triangular. +// +// If unconverged > 0 and compz == lapack.SchurOrig, then on return +// +// (final Z) = (initial Z) U, +// +// where U is the orthogonal matrix in (*) regardless of the value of job. +// +// If unconverged > 0 and compz == lapack.SchurHess, then on return +// +// (final Z) = U, +// +// where U is the orthogonal matrix in (*) regardless of the value of job. +// +// References: +// +// [1] R. Byers. LAPACK 3.1 xHSEQR: Tuning and Implementation Notes on the +// Small Bulge Multi-Shift QR Algorithm with Aggressive Early Deflation. +// LAPACK Working Note 187 (2007) +// URL: http://www.netlib.org/lapack/lawnspdf/lawn187.pdf +// [2] K. Braman, R. Byers, R. Mathias. The Multishift QR Algorithm. Part I: +// Maintaining Well-Focused Shifts and Level 3 Performance. SIAM J. Matrix +// Anal. Appl. 23(4) (2002), pp. 929—947 +// URL: http://dx.doi.org/10.1137/S0895479801384573 +// [3] K. Braman, R. Byers, R. Mathias. The Multishift QR Algorithm. Part II: +// Aggressive Early Deflation. SIAM J. Matrix Anal. Appl. 23(4) (2002), pp. 948—973 +// URL: http://dx.doi.org/10.1137/S0895479801384585 +// +// Dhseqr is an internal routine. It is exported for testing purposes. 
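Before the function itself, a short usage sketch may help (editorial illustration only, not part of the vendored file; the matrix entries are arbitrary example data). It follows the LAPACK workspace-query convention documented above: call once with lwork == -1 to obtain the optimal workspace size in work[0], then call again with the allocated workspace. Eigenvalues only, no Schur vectors:

```go
package main

import (
	"fmt"

	"gonum.org/v1/gonum/lapack"
	"gonum.org/v1/gonum/lapack/gonum"
)

func main() {
	var impl gonum.Implementation
	n := 3
	// An upper Hessenberg matrix in row-major order (example data).
	h := []float64{
		1, 2, 3,
		4, 5, 6,
		0, 7, 8,
	}
	wr := make([]float64, n) // real parts of the eigenvalues
	wi := make([]float64, n) // imaginary parts of the eigenvalues

	// Workspace query: with lwork == -1 only work[0] is set.
	work := make([]float64, 1)
	impl.Dhseqr(lapack.EigenvaluesOnly, lapack.SchurNone, n, 0, n-1,
		h, n, wr, wi, nil, 1, work, -1)
	work = make([]float64, int(work[0]))

	// Real call: eigenvalues only, so z is not referenced and ldz may be 1.
	unconverged := impl.Dhseqr(lapack.EigenvaluesOnly, lapack.SchurNone, n, 0, n-1,
		h, n, wr, wi, nil, 1, work, len(work))
	fmt.Println(unconverged, wr, wi)
}
```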
+func (impl Implementation) Dhseqr(job lapack.SchurJob, compz lapack.SchurComp, n, ilo, ihi int, h []float64, ldh int, wr, wi []float64, z []float64, ldz int, work []float64, lwork int) (unconverged int) { + wantt := job == lapack.EigenvaluesAndSchur + wantz := compz == lapack.SchurHess || compz == lapack.SchurOrig + + switch { + case job != lapack.EigenvaluesOnly && job != lapack.EigenvaluesAndSchur: + panic(badSchurJob) + case compz != lapack.SchurNone && compz != lapack.SchurHess && compz != lapack.SchurOrig: + panic(badSchurComp) + case n < 0: + panic(nLT0) + case ilo < 0 || max(0, n-1) < ilo: + panic(badIlo) + case ihi < min(ilo, n-1) || n <= ihi: + panic(badIhi) + case ldh < max(1, n): + panic(badLdH) + case ldz < 1, wantz && ldz < n: + panic(badLdZ) + case lwork < max(1, n) && lwork != -1: + panic(badLWork) + case len(work) < max(1, lwork): + panic(shortWork) + } + + // Quick return if possible. + if n == 0 { + work[0] = 1 + return 0 + } + + // Quick return in case of a workspace query. + if lwork == -1 { + impl.Dlaqr04(wantt, wantz, n, ilo, ihi, h, ldh, wr, wi, ilo, ihi, z, ldz, work, -1, 1) + work[0] = math.Max(float64(n), work[0]) + return 0 + } + + switch { + case len(h) < (n-1)*ldh+n: + panic(shortH) + case wantz && len(z) < (n-1)*ldz+n: + panic(shortZ) + case len(wr) < n: + panic(shortWr) + case len(wi) < n: + panic(shortWi) + } + + const ( + // Matrices of order ntiny or smaller must be processed by + // Dlahqr because of insufficient subdiagonal scratch space. + // This is a hard limit. + ntiny = 11 + + // nl is the size of a local workspace to help small matrices + // through a rare Dlahqr failure. nl > ntiny is required and + // nl <= nmin = Ilaenv(ispec=12,...) is recommended (the default + // value of nmin is 75). Using nl = 49 allows up to six + // simultaneous shifts and a 16×16 deflation window. + nl = 49 + ) + + // Copy eigenvalues isolated by Dgebal. + for i := 0; i < ilo; i++ { + wr[i] = h[i*ldh+i] + wi[i] = 0 + } + for i := ihi + 1; i < n; i++ { + wr[i] = h[i*ldh+i] + wi[i] = 0 + } + + // Initialize Z to identity matrix if requested. + if compz == lapack.SchurHess { + impl.Dlaset(blas.All, n, n, 0, 1, z, ldz) + } + + // Quick return if possible. + if ilo == ihi { + wr[ilo] = h[ilo*ldh+ilo] + wi[ilo] = 0 + return 0 + } + + // Dlahqr/Dlaqr04 crossover point. + nmin := impl.Ilaenv(12, "DHSEQR", string(job)+string(compz), n, ilo, ihi, lwork) + nmin = max(ntiny, nmin) + + if n > nmin { + // Dlaqr0 for big matrices. + unconverged = impl.Dlaqr04(wantt, wantz, n, ilo, ihi, h, ldh, wr[:ihi+1], wi[:ihi+1], + ilo, ihi, z, ldz, work, lwork, 1) + } else { + // Dlahqr for small matrices. + unconverged = impl.Dlahqr(wantt, wantz, n, ilo, ihi, h, ldh, wr[:ihi+1], wi[:ihi+1], + ilo, ihi, z, ldz) + if unconverged > 0 { + // A rare Dlahqr failure! Dlaqr04 sometimes succeeds + // when Dlahqr fails. + kbot := unconverged + if n >= nl { + // Larger matrices have enough subdiagonal + // scratch space to call Dlaqr04 directly. + unconverged = impl.Dlaqr04(wantt, wantz, n, ilo, kbot, h, ldh, + wr[:ihi+1], wi[:ihi+1], ilo, ihi, z, ldz, work, lwork, 1) + } else { + // Tiny matrices don't have enough subdiagonal + // scratch space to benefit from Dlaqr04. Hence, + // tiny matrices must be copied into a larger + // array before calling Dlaqr04. 
+ var hl [nl * nl]float64 + impl.Dlacpy(blas.All, n, n, h, ldh, hl[:], nl) + impl.Dlaset(blas.All, nl, nl-n, 0, 0, hl[n:], nl) + var workl [nl]float64 + unconverged = impl.Dlaqr04(wantt, wantz, nl, ilo, kbot, hl[:], nl, + wr[:ihi+1], wi[:ihi+1], ilo, ihi, z, ldz, workl[:], nl, 1) + work[0] = workl[0] + if wantt || unconverged > 0 { + impl.Dlacpy(blas.All, n, n, hl[:], nl, h, ldh) + } + } + } + } + // Zero out under the first subdiagonal, if necessary. + if (wantt || unconverged > 0) && n > 2 { + impl.Dlaset(blas.Lower, n-2, n-2, 0, 0, h[2*ldh:], ldh) + } + + work[0] = math.Max(float64(n), work[0]) + return unconverged +} diff --git a/vendor/gonum.org/v1/gonum/lapack/gonum/dlahqr.go b/vendor/gonum.org/v1/gonum/lapack/gonum/dlahqr.go new file mode 100644 index 0000000000..13f2856015 --- /dev/null +++ b/vendor/gonum.org/v1/gonum/lapack/gonum/dlahqr.go @@ -0,0 +1,441 @@ +// Copyright ©2016 The Gonum Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gonum + +import ( + "math" + + "gonum.org/v1/gonum/blas/blas64" +) + +// Dlahqr computes the eigenvalues and Schur factorization of a block of an n×n +// upper Hessenberg matrix H, using the double-shift/single-shift QR algorithm. +// +// h and ldh represent the matrix H. Dlahqr works primarily with the Hessenberg +// submatrix H[ilo:ihi+1,ilo:ihi+1], but applies transformations to all of H if +// wantt is true. It is assumed that H[ihi+1:n,ihi+1:n] is already upper +// quasi-triangular, although this is not checked. +// +// It must hold that +// +// 0 <= ilo <= max(0,ihi), and ihi < n, +// +// and that +// +// H[ilo,ilo-1] == 0, if ilo > 0, +// +// otherwise Dlahqr will panic. +// +// If unconverged is zero on return, wr[ilo:ihi+1] and wi[ilo:ihi+1] will contain +// respectively the real and imaginary parts of the computed eigenvalues ilo +// to ihi. If two eigenvalues are computed as a complex conjugate pair, they are +// stored in consecutive elements of wr and wi, say the i-th and (i+1)th, with +// wi[i] > 0 and wi[i+1] < 0. If wantt is true, the eigenvalues are stored in +// the same order as on the diagonal of the Schur form returned in H, with +// wr[i] = H[i,i], and, if H[i:i+2,i:i+2] is a 2×2 diagonal block, +// wi[i] = sqrt(abs(H[i+1,i]*H[i,i+1])) and wi[i+1] = -wi[i]. +// +// wr and wi must have length ihi+1. +// +// z and ldz represent an n×n matrix Z. If wantz is true, the transformations +// will be applied to the submatrix Z[iloz:ihiz+1,ilo:ihi+1] and it must hold that +// +// 0 <= iloz <= ilo, and ihi <= ihiz < n. +// +// If wantz is false, z is not referenced. +// +// unconverged indicates whether Dlahqr computed all the eigenvalues ilo to ihi +// in a total of 30 iterations per eigenvalue. +// +// If unconverged is zero, all the eigenvalues ilo to ihi have been computed and +// will be stored on return in wr[ilo:ihi+1] and wi[ilo:ihi+1]. +// +// If unconverged is zero and wantt is true, H[ilo:ihi+1,ilo:ihi+1] will be +// overwritten on return by upper quasi-triangular full Schur form with any +// 2×2 diagonal blocks in standard form. +// +// If unconverged is zero and if wantt is false, the contents of h on return is +// unspecified. +// +// If unconverged is positive, some eigenvalues have not converged, and +// wr[unconverged:ihi+1] and wi[unconverged:ihi+1] contain those eigenvalues +// which have been successfully computed. 
+// +// If unconverged is positive and wantt is true, then on return +// +// (initial H)*U = U*(final H), (*) +// +// where U is an orthogonal matrix. The final H is upper Hessenberg and +// H[unconverged:ihi+1,unconverged:ihi+1] is upper quasi-triangular. +// +// If unconverged is positive and wantt is false, on return the remaining +// unconverged eigenvalues are the eigenvalues of the upper Hessenberg matrix +// H[ilo:unconverged,ilo:unconverged]. +// +// If unconverged is positive and wantz is true, then on return +// +// (final Z) = (initial Z)*U, +// +// where U is the orthogonal matrix in (*) regardless of the value of wantt. +// +// Dlahqr is an internal routine. It is exported for testing purposes. +func (impl Implementation) Dlahqr(wantt, wantz bool, n, ilo, ihi int, h []float64, ldh int, wr, wi []float64, iloz, ihiz int, z []float64, ldz int) (unconverged int) { + switch { + case n < 0: + panic(nLT0) + case ilo < 0, max(0, ihi) < ilo: + panic(badIlo) + case ihi >= n: + panic(badIhi) + case ldh < max(1, n): + panic(badLdH) + case wantz && (iloz < 0 || ilo < iloz): + panic(badIloz) + case wantz && (ihiz < ihi || n <= ihiz): + panic(badIhiz) + case ldz < 1, wantz && ldz < n: + panic(badLdZ) + } + + // Quick return if possible. + if n == 0 { + return 0 + } + + switch { + case len(h) < (n-1)*ldh+n: + panic(shortH) + case len(wr) != ihi+1: + panic(shortWr) + case len(wi) != ihi+1: + panic(shortWi) + case wantz && len(z) < (n-1)*ldz+n: + panic(shortZ) + case ilo > 0 && h[ilo*ldh+ilo-1] != 0: + panic(notIsolated) + } + + if ilo == ihi { + wr[ilo] = h[ilo*ldh+ilo] + wi[ilo] = 0 + return 0 + } + + // Clear out the trash. + for j := ilo; j < ihi-2; j++ { + h[(j+2)*ldh+j] = 0 + h[(j+3)*ldh+j] = 0 + } + if ilo <= ihi-2 { + h[ihi*ldh+ihi-2] = 0 + } + + nh := ihi - ilo + 1 + nz := ihiz - iloz + 1 + + // Set machine-dependent constants for the stopping criterion. + ulp := dlamchP + smlnum := float64(nh) / ulp * dlamchS + + // i1 and i2 are the indices of the first row and last column of H to + // which transformations must be applied. If eigenvalues only are being + // computed, i1 and i2 are set inside the main loop. + var i1, i2 int + if wantt { + i1 = 0 + i2 = n - 1 + } + + itmax := 30 * max(10, nh) // Total number of QR iterations allowed. + + // The main loop begins here. i is the loop index and decreases from ihi + // to ilo in steps of 1 or 2. Each iteration of the loop works with the + // active submatrix in rows and columns l to i. Eigenvalues i+1 to ihi + // have already converged. Either l = ilo or H[l,l-1] is negligible so + // that the matrix splits. + bi := blas64.Implementation() + i := ihi + for i >= ilo { + l := ilo + + // Perform QR iterations on rows and columns ilo to i until a + // submatrix of order 1 or 2 splits off at the bottom because a + // subdiagonal element has become negligible. + converged := false + for its := 0; its <= itmax; its++ { + // Look for a single small subdiagonal element. + var k int + for k = i; k > l; k-- { + if math.Abs(h[k*ldh+k-1]) <= smlnum { + break + } + tst := math.Abs(h[(k-1)*ldh+k-1]) + math.Abs(h[k*ldh+k]) + if tst == 0 { + if k-2 >= ilo { + tst += math.Abs(h[(k-1)*ldh+k-2]) + } + if k+1 <= ihi { + tst += math.Abs(h[(k+1)*ldh+k]) + } + } + // The following is a conservative small + // subdiagonal deflation criterion due to Ahues + // & Tisseur (LAWN 122, 1997). It has better + // mathematical foundation and improves accuracy + // in some cases. 
+ if math.Abs(h[k*ldh+k-1]) <= ulp*tst { + ab := math.Max(math.Abs(h[k*ldh+k-1]), math.Abs(h[(k-1)*ldh+k])) + ba := math.Min(math.Abs(h[k*ldh+k-1]), math.Abs(h[(k-1)*ldh+k])) + aa := math.Max(math.Abs(h[k*ldh+k]), math.Abs(h[(k-1)*ldh+k-1]-h[k*ldh+k])) + bb := math.Min(math.Abs(h[k*ldh+k]), math.Abs(h[(k-1)*ldh+k-1]-h[k*ldh+k])) + s := aa + ab + if ab/s*ba <= math.Max(smlnum, aa/s*bb*ulp) { + break + } + } + } + l = k + if l > ilo { + // H[l,l-1] is negligible. + h[l*ldh+l-1] = 0 + } + if l >= i-1 { + // Break the loop because a submatrix of order 1 + // or 2 has split off. + converged = true + break + } + + // Now the active submatrix is in rows and columns l to + // i. If eigenvalues only are being computed, only the + // active submatrix need be transformed. + if !wantt { + i1 = l + i2 = i + } + + const ( + dat1 = 3.0 + dat2 = -0.4375 + ) + var h11, h21, h12, h22 float64 + switch its { + case 10: // Exceptional shift. + s := math.Abs(h[(l+1)*ldh+l]) + math.Abs(h[(l+2)*ldh+l+1]) + h11 = dat1*s + h[l*ldh+l] + h12 = dat2 * s + h21 = s + h22 = h11 + case 20: // Exceptional shift. + s := math.Abs(h[i*ldh+i-1]) + math.Abs(h[(i-1)*ldh+i-2]) + h11 = dat1*s + h[i*ldh+i] + h12 = dat2 * s + h21 = s + h22 = h11 + default: // Prepare to use Francis' double shift (i.e., + // 2nd degree generalized Rayleigh quotient). + h11 = h[(i-1)*ldh+i-1] + h21 = h[i*ldh+i-1] + h12 = h[(i-1)*ldh+i] + h22 = h[i*ldh+i] + } + s := math.Abs(h11) + math.Abs(h12) + math.Abs(h21) + math.Abs(h22) + var ( + rt1r, rt1i float64 + rt2r, rt2i float64 + ) + if s != 0 { + h11 /= s + h21 /= s + h12 /= s + h22 /= s + tr := (h11 + h22) / 2 + det := (h11-tr)*(h22-tr) - h12*h21 + rtdisc := math.Sqrt(math.Abs(det)) + if det >= 0 { + // Complex conjugate shifts. + rt1r = tr * s + rt2r = rt1r + rt1i = rtdisc * s + rt2i = -rt1i + } else { + // Real shifts (use only one of them). + rt1r = tr + rtdisc + rt2r = tr - rtdisc + if math.Abs(rt1r-h22) <= math.Abs(rt2r-h22) { + rt1r *= s + rt2r = rt1r + } else { + rt2r *= s + rt1r = rt2r + } + rt1i = 0 + rt2i = 0 + } + } + + // Look for two consecutive small subdiagonal elements. + var m int + var v [3]float64 + for m = i - 2; m >= l; m-- { + // Determine the effect of starting the + // double-shift QR iteration at row m, and see + // if this would make H[m,m-1] negligible. The + // following uses scaling to avoid overflows and + // most underflows. + h21s := h[(m+1)*ldh+m] + s := math.Abs(h[m*ldh+m]-rt2r) + math.Abs(rt2i) + math.Abs(h21s) + h21s /= s + v[0] = h21s*h[m*ldh+m+1] + (h[m*ldh+m]-rt1r)*((h[m*ldh+m]-rt2r)/s) - rt2i/s*rt1i + v[1] = h21s * (h[m*ldh+m] + h[(m+1)*ldh+m+1] - rt1r - rt2r) + v[2] = h21s * h[(m+2)*ldh+m+1] + s = math.Abs(v[0]) + math.Abs(v[1]) + math.Abs(v[2]) + v[0] /= s + v[1] /= s + v[2] /= s + if m == l { + break + } + dsum := math.Abs(h[(m-1)*ldh+m-1]) + math.Abs(h[m*ldh+m]) + math.Abs(h[(m+1)*ldh+m+1]) + if math.Abs(h[m*ldh+m-1])*(math.Abs(v[1])+math.Abs(v[2])) <= ulp*math.Abs(v[0])*dsum { + break + } + } + + // Double-shift QR step. + for k := m; k < i; k++ { + // The first iteration of this loop determines a + // reflection G from the vector V and applies it + // from left and right to H, thus creating a + // non-zero bulge below the subdiagonal. + // + // Each subsequent iteration determines a + // reflection G to restore the Hessenberg form + // in the (k-1)th column, and thus chases the + // bulge one step toward the bottom of the + // active submatrix. nr is the order of G. 
+ + nr := min(3, i-k+1) + if k > m { + bi.Dcopy(nr, h[k*ldh+k-1:], ldh, v[:], 1) + } + var t0 float64 + v[0], t0 = impl.Dlarfg(nr, v[0], v[1:], 1) + if k > m { + h[k*ldh+k-1] = v[0] + h[(k+1)*ldh+k-1] = 0 + if k < i-1 { + h[(k+2)*ldh+k-1] = 0 + } + } else if m > l { + // Use the following instead of H[k,k-1] = -H[k,k-1] + // to avoid a bug when v[1] and v[2] underflow. + h[k*ldh+k-1] *= 1 - t0 + } + t1 := t0 * v[1] + if nr == 3 { + t2 := t0 * v[2] + + // Apply G from the left to transform + // the rows of the matrix in columns k + // to i2. + for j := k; j <= i2; j++ { + sum := h[k*ldh+j] + v[1]*h[(k+1)*ldh+j] + v[2]*h[(k+2)*ldh+j] + h[k*ldh+j] -= sum * t0 + h[(k+1)*ldh+j] -= sum * t1 + h[(k+2)*ldh+j] -= sum * t2 + } + + // Apply G from the right to transform + // the columns of the matrix in rows i1 + // to min(k+3,i). + for j := i1; j <= min(k+3, i); j++ { + sum := h[j*ldh+k] + v[1]*h[j*ldh+k+1] + v[2]*h[j*ldh+k+2] + h[j*ldh+k] -= sum * t0 + h[j*ldh+k+1] -= sum * t1 + h[j*ldh+k+2] -= sum * t2 + } + + if wantz { + // Accumulate transformations in the matrix Z. + for j := iloz; j <= ihiz; j++ { + sum := z[j*ldz+k] + v[1]*z[j*ldz+k+1] + v[2]*z[j*ldz+k+2] + z[j*ldz+k] -= sum * t0 + z[j*ldz+k+1] -= sum * t1 + z[j*ldz+k+2] -= sum * t2 + } + } + } else if nr == 2 { + // Apply G from the left to transform + // the rows of the matrix in columns k + // to i2. + for j := k; j <= i2; j++ { + sum := h[k*ldh+j] + v[1]*h[(k+1)*ldh+j] + h[k*ldh+j] -= sum * t0 + h[(k+1)*ldh+j] -= sum * t1 + } + + // Apply G from the right to transform + // the columns of the matrix in rows i1 + // to min(k+3,i). + for j := i1; j <= i; j++ { + sum := h[j*ldh+k] + v[1]*h[j*ldh+k+1] + h[j*ldh+k] -= sum * t0 + h[j*ldh+k+1] -= sum * t1 + } + + if wantz { + // Accumulate transformations in the matrix Z. + for j := iloz; j <= ihiz; j++ { + sum := z[j*ldz+k] + v[1]*z[j*ldz+k+1] + z[j*ldz+k] -= sum * t0 + z[j*ldz+k+1] -= sum * t1 + } + } + } + } + } + + if !converged { + // The QR iteration finished without splitting off a + // submatrix of order 1 or 2. + return i + 1 + } + + if l == i { + // H[i,i-1] is negligible: one eigenvalue has converged. + wr[i] = h[i*ldh+i] + wi[i] = 0 + } else if l == i-1 { + // H[i-1,i-2] is negligible: a pair of eigenvalues have converged. + + // Transform the 2×2 submatrix to standard Schur form, + // and compute and store the eigenvalues. + var cs, sn float64 + a, b := h[(i-1)*ldh+i-1], h[(i-1)*ldh+i] + c, d := h[i*ldh+i-1], h[i*ldh+i] + a, b, c, d, wr[i-1], wi[i-1], wr[i], wi[i], cs, sn = impl.Dlanv2(a, b, c, d) + h[(i-1)*ldh+i-1], h[(i-1)*ldh+i] = a, b + h[i*ldh+i-1], h[i*ldh+i] = c, d + + if wantt { + // Apply the transformation to the rest of H. + if i2 > i { + bi.Drot(i2-i, h[(i-1)*ldh+i+1:], 1, h[i*ldh+i+1:], 1, cs, sn) + } + bi.Drot(i-i1-1, h[i1*ldh+i-1:], ldh, h[i1*ldh+i:], ldh, cs, sn) + } + + if wantz { + // Apply the transformation to Z. + bi.Drot(nz, z[iloz*ldz+i-1:], ldz, z[iloz*ldz+i:], ldz, cs, sn) + } + } + + // Return to start of the main loop with new value of i. + i = l - 1 + } + return 0 +} diff --git a/vendor/gonum.org/v1/gonum/lapack/gonum/dlaqr04.go b/vendor/gonum.org/v1/gonum/lapack/gonum/dlaqr04.go new file mode 100644 index 0000000000..3faaa2fc19 --- /dev/null +++ b/vendor/gonum.org/v1/gonum/lapack/gonum/dlaqr04.go @@ -0,0 +1,493 @@ +// Copyright ©2016 The Gonum Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package gonum + +import ( + "math" + + "gonum.org/v1/gonum/blas" +) + +// Dlaqr04 computes the eigenvalues of a block of an n×n upper Hessenberg matrix +// H, and optionally the matrices T and Z from the Schur decomposition +// +// H = Z T Zᵀ +// +// where T is an upper quasi-triangular matrix (the Schur form), and Z is the +// orthogonal matrix of Schur vectors. +// +// wantt indicates whether the full Schur form T is required. If wantt is false, +// then only enough of H will be updated to preserve the eigenvalues. +// +// wantz indicates whether the n×n matrix of Schur vectors Z is required. If it +// is true, the orthogonal similarity transformation will be accumulated into +// Z[iloz:ihiz+1,ilo:ihi+1], otherwise Z will not be referenced. +// +// ilo and ihi determine the block of H on which Dlaqr04 operates. It must hold that +// +// 0 <= ilo <= ihi < n if n > 0, +// ilo == 0 and ihi == -1 if n == 0, +// +// and the block must be isolated, that is, +// +// ilo == 0 or H[ilo,ilo-1] == 0, +// ihi == n-1 or H[ihi+1,ihi] == 0, +// +// otherwise Dlaqr04 will panic. +// +// wr and wi must have length ihi+1. +// +// iloz and ihiz specify the rows of Z to which transformations will be applied +// if wantz is true. It must hold that +// +// 0 <= iloz <= ilo, and ihi <= ihiz < n, +// +// otherwise Dlaqr04 will panic. +// +// work must have length at least lwork and lwork must be +// +// lwork >= 1 if n <= 11, +// lwork >= n if n > 11, +// +// otherwise Dlaqr04 will panic. lwork as large as 6*n may be required for +// optimal performance. On return, work[0] will contain the optimal value of +// lwork. +// +// If lwork is -1, instead of performing Dlaqr04, the function only estimates the +// optimal workspace size and stores it into work[0]. Neither h nor z are +// accessed. +// +// recur is the non-negative recursion depth. For recur > 0, Dlaqr04 behaves +// as DLAQR0, for recur == 0 it behaves as DLAQR4. +// +// unconverged indicates whether Dlaqr04 computed all the eigenvalues of H[ilo:ihi+1,ilo:ihi+1]. +// +// If unconverged is zero and wantt is true, H will contain on return the upper +// quasi-triangular matrix T from the Schur decomposition. 2×2 diagonal blocks +// (corresponding to complex conjugate pairs of eigenvalues) will be returned in +// standard form, with H[i,i] == H[i+1,i+1] and H[i+1,i]*H[i,i+1] < 0. +// +// If unconverged is zero and if wantt is false, the contents of h on return is +// unspecified. +// +// If unconverged is zero, all the eigenvalues have been computed and their real +// and imaginary parts will be stored on return in wr[ilo:ihi+1] and +// wi[ilo:ihi+1], respectively. If two eigenvalues are computed as a complex +// conjugate pair, they are stored in consecutive elements of wr and wi, say the +// i-th and (i+1)th, with wi[i] > 0 and wi[i+1] < 0. If wantt is true, then the +// eigenvalues are stored in the same order as on the diagonal of the Schur form +// returned in H, with wr[i] = H[i,i] and, if H[i:i+2,i:i+2] is a 2×2 diagonal +// block, wi[i] = sqrt(-H[i+1,i]*H[i,i+1]) and wi[i+1] = -wi[i]. +// +// If unconverged is positive, some eigenvalues have not converged, and +// wr[unconverged:ihi+1] and wi[unconverged:ihi+1] will contain those +// eigenvalues which have been successfully computed. Failures are rare. +// +// If unconverged is positive and wantt is true, then on return +// +// (initial H)*U = U*(final H), (*) +// +// where U is an orthogonal matrix. 
+// The final H is upper Hessenberg and
+// H[unconverged:ihi+1,unconverged:ihi+1] is upper quasi-triangular.
+//
+// If unconverged is positive and wantt is false, on return the remaining
+// unconverged eigenvalues are the eigenvalues of the upper Hessenberg matrix
+// H[ilo:unconverged,ilo:unconverged].
+//
+// If unconverged is positive and wantz is true, then on return
+//
+// (final Z) = (initial Z)*U,
+//
+// where U is the orthogonal matrix in (*) regardless of the value of wantt.
+//
+// References:
+//
+// [1] K. Braman, R. Byers, R. Mathias. The Multishift QR Algorithm. Part I:
+// Maintaining Well-Focused Shifts and Level 3 Performance. SIAM J. Matrix
+// Anal. Appl. 23(4) (2002), pp. 929—947
+// URL: http://dx.doi.org/10.1137/S0895479801384573
+// [2] K. Braman, R. Byers, R. Mathias. The Multishift QR Algorithm. Part II:
+// Aggressive Early Deflation. SIAM J. Matrix Anal. Appl. 23(4) (2002), pp. 948—973
+// URL: http://dx.doi.org/10.1137/S0895479801384585
+//
+// Dlaqr04 is an internal routine. It is exported for testing purposes.
+func (impl Implementation) Dlaqr04(wantt, wantz bool, n, ilo, ihi int, h []float64, ldh int, wr, wi []float64, iloz, ihiz int, z []float64, ldz int, work []float64, lwork int, recur int) (unconverged int) {
+ const (
+ // Matrices of order ntiny or smaller must be processed by
+ // Dlahqr because of insufficient subdiagonal scratch space.
+ // This is a hard limit.
+ ntiny = 11
+ // Exceptional deflation windows: try to cure rare slow
+ // convergence by varying the size of the deflation window after
+ // kexnw iterations.
+ kexnw = 5
+ // Exceptional shifts: try to cure rare slow convergence with
+ // ad-hoc exceptional shifts every kexsh iterations.
+ kexsh = 6
+
+ // See https://github.com/gonum/lapack/pull/151#discussion_r68162802
+ // and the surrounding discussion for an explanation where these
+ // constants come from.
+ // TODO(vladimir-ch): Similar constants for exceptional shifts
+ // are used also in dlahqr.go. The first constant is different
+ // there, it is equal to 3. Why? And does it matter?
+ wilk1 = 0.75
+ wilk2 = -0.4375
+ )
+
+ switch {
+ case n < 0:
+ panic(nLT0)
+ case ilo < 0 || max(0, n-1) < ilo:
+ panic(badIlo)
+ case ihi < min(ilo, n-1) || n <= ihi:
+ panic(badIhi)
+ case ldh < max(1, n):
+ panic(badLdH)
+ case wantz && (iloz < 0 || ilo < iloz):
+ panic(badIloz)
+ case wantz && (ihiz < ihi || n <= ihiz):
+ panic(badIhiz)
+ case ldz < 1, wantz && ldz < n:
+ panic(badLdZ)
+ case lwork < 1 && lwork != -1:
+ panic(badLWork)
+ // TODO(vladimir-ch): Enable if and when we figure out what the minimum
+ // necessary lwork value is. Dlaqr04 says that the minimum is n which
+ // clashes with Dlaqr23's opinion about optimal work when nw <= 2
+ // (independent of n).
+ // case lwork < n && n > ntiny && lwork != -1:
+ // panic(badLWork)
+ case len(work) < max(1, lwork):
+ panic(shortWork)
+ case recur < 0:
+ panic(recurLT0)
+ }
+
+ // Quick return.
+ if n == 0 {
+ work[0] = 1
+ return 0
+ }
+
+ if lwork != -1 {
+ switch {
+ case len(h) < (n-1)*ldh+n:
+ panic(shortH)
+ case len(wr) != ihi+1:
+ panic(badLenWr)
+ case len(wi) != ihi+1:
+ panic(badLenWi)
+ case wantz && len(z) < (n-1)*ldz+n:
+ panic(shortZ)
+ case ilo > 0 && h[ilo*ldh+ilo-1] != 0:
+ panic(notIsolated)
+ case ihi+1 < n && h[(ihi+1)*ldh+ihi] != 0:
+ panic(notIsolated)
+ }
+ }
+
+ if n <= ntiny {
+ // Tiny matrices must use Dlahqr.
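The lwork == -1 convention documented above leads to the usual LAPACK two-call pattern: query, allocate, compute. A hedged usage sketch (sizes illustrative; it assumes h was already reduced to Hessenberg form, e.g. by Dgehrd/Dorghr, and error handling is elided):

impl := gonum.Implementation{}
n := 200
h := make([]float64, n*n)  // row-major upper Hessenberg matrix, stride n
z := make([]float64, n*n)  // accumulated orthogonal factor, identity on entry
wr := make([]float64, n)
wi := make([]float64, n)

// First call: workspace query only; work[0] receives the optimal lwork.
work := make([]float64, 1)
impl.Dlaqr04(true, true, n, 0, n-1, h, n, wr, wi, 0, n-1, z, n, work, -1, 1)

// Second call: the actual computation with the recommended workspace.
work = make([]float64, int(work[0]))
unconverged := impl.Dlaqr04(true, true, n, 0, n-1, h, n, wr, wi, 0, n-1, z, n, work, len(work), 1)
_ = unconverged // 0 on full convergence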
+ if lwork == -1 { + work[0] = 1 + return 0 + } + return impl.Dlahqr(wantt, wantz, n, ilo, ihi, h, ldh, wr, wi, iloz, ihiz, z, ldz) + } + + // Use small bulge multi-shift QR with aggressive early deflation on + // larger-than-tiny matrices. + var jbcmpz string + if wantt { + jbcmpz = "S" + } else { + jbcmpz = "E" + } + if wantz { + jbcmpz += "V" + } else { + jbcmpz += "N" + } + + var fname string + if recur > 0 { + fname = "DLAQR0" + } else { + fname = "DLAQR4" + } + // nwr is the recommended deflation window size. n is greater than 11, + // so there is enough subdiagonal workspace for nwr >= 2 as required. + // (In fact, there is enough subdiagonal space for nwr >= 3.) + // TODO(vladimir-ch): If there is enough space for nwr >= 3, should we + // use it? + nwr := impl.Ilaenv(13, fname, jbcmpz, n, ilo, ihi, lwork) + nwr = max(2, nwr) + nwr = min(ihi-ilo+1, min((n-1)/3, nwr)) + + // nsr is the recommended number of simultaneous shifts. n is greater + // than 11, so there is enough subdiagonal workspace for nsr to be even + // and greater than or equal to two as required. + nsr := impl.Ilaenv(15, fname, jbcmpz, n, ilo, ihi, lwork) + nsr = min(nsr, min((n+6)/9, ihi-ilo)) + nsr = max(2, nsr&^1) + + // Workspace query call to Dlaqr23. + impl.Dlaqr23(wantt, wantz, n, ilo, ihi, nwr+1, h, ldh, iloz, ihiz, z, ldz, + wr, wi, h, ldh, n, h, ldh, n, h, ldh, work, -1, recur) + // Optimal workspace is max(Dlaqr5, Dlaqr23). + lwkopt := max(3*nsr/2, int(work[0])) + // Quick return in case of workspace query. + if lwork == -1 { + work[0] = float64(lwkopt) + return 0 + } + + // Dlahqr/Dlaqr04 crossover point. + nmin := impl.Ilaenv(12, fname, jbcmpz, n, ilo, ihi, lwork) + nmin = max(ntiny, nmin) + + // Nibble determines when to skip a multi-shift QR sweep (Dlaqr5). + nibble := impl.Ilaenv(14, fname, jbcmpz, n, ilo, ihi, lwork) + nibble = max(0, nibble) + + // Computation mode of far-from-diagonal orthogonal updates in Dlaqr5. + kacc22 := impl.Ilaenv(16, fname, jbcmpz, n, ilo, ihi, lwork) + kacc22 = max(0, min(kacc22, 2)) + + // nwmax is the largest possible deflation window for which there is + // sufficient workspace. + nwmax := min((n-1)/3, lwork/2) + nw := nwmax // Start with maximum deflation window size. + + // nsmax is the largest number of simultaneous shifts for which there is + // sufficient workspace. + nsmax := min((n+6)/9, 2*lwork/3) &^ 1 + + ndfl := 1 // Number of iterations since last deflation. + ndec := 0 // Deflation window size decrement. + + // Main loop. + var ( + itmax = max(30, 2*kexsh) * max(10, (ihi-ilo+1)) + it = 0 + ) + for kbot := ihi; kbot >= ilo; { + if it == itmax { + unconverged = kbot + 1 + break + } + it++ + + // Locate active block. + ktop := ilo + for k := kbot; k >= ilo+1; k-- { + if h[k*ldh+k-1] == 0 { + ktop = k + break + } + } + + // Select deflation window size nw. + // + // Typical Case: + // If possible and advisable, nibble the entire active block. + // If not, use size min(nwr,nwmax) or min(nwr+1,nwmax) + // depending upon which has the smaller corresponding + // subdiagonal entry (a heuristic). + // + // Exceptional Case: + // If there have been no deflations in kexnw or more + // iterations, then vary the deflation window size. At first, + // because larger windows are, in general, more powerful than + // smaller ones, rapidly increase the window to the maximum + // possible. Then, gradually reduce the window size. 
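The window-size policy just described is compact enough to restate as a pure function. This is an illustrative paraphrase of the selection performed in the next lines, not an exact transcription (nextNW is a made-up name):

func min(a, b int) int {
	if a < b {
		return a
	}
	return b
}

// nextNW keeps the recommended width nwr while deflations are still
// happening (ndfl < kexnw), then grows the window geometrically, always
// capped by nwupbd = min(nh, nwmax).
func nextNW(ndfl, kexnw, nw, nwr, nwupbd int) int {
	if ndfl < kexnw {
		return min(nwupbd, nwr)
	}
	return min(nwupbd, 2*nw)
}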
+ nh := kbot - ktop + 1 + nwupbd := min(nh, nwmax) + if ndfl < kexnw { + nw = min(nwupbd, nwr) + } else { + nw = min(nwupbd, 2*nw) + } + if nw < nwmax { + if nw >= nh-1 { + nw = nh + } else { + kwtop := kbot - nw + 1 + if math.Abs(h[kwtop*ldh+kwtop-1]) > math.Abs(h[(kwtop-1)*ldh+kwtop-2]) { + nw++ + } + } + } + if ndfl < kexnw { + ndec = -1 + } else if ndec >= 0 || nw >= nwupbd { + ndec++ + if nw-ndec < 2 { + ndec = 0 + } + nw -= ndec + } + + // Split workspace under the subdiagonal of H into: + // - an nw×nw work array V in the lower left-hand corner, + // - an nw×nhv horizontal work array along the bottom edge (nhv + // must be at least nw but more is better), + // - an nve×nw vertical work array along the left-hand-edge + // (nhv can be any positive integer but more is better). + kv := n - nw + kt := nw + kwv := nw + 1 + nhv := n - kwv - kt + // Aggressive early deflation. + ls, ld := impl.Dlaqr23(wantt, wantz, n, ktop, kbot, nw, + h, ldh, iloz, ihiz, z, ldz, wr[:kbot+1], wi[:kbot+1], + h[kv*ldh:], ldh, nhv, h[kv*ldh+kt:], ldh, nhv, h[kwv*ldh:], ldh, work, lwork, recur) + + // Adjust kbot accounting for new deflations. + kbot -= ld + // ks points to the shifts. + ks := kbot - ls + 1 + + // Skip an expensive QR sweep if there is a (partly heuristic) + // reason to expect that many eigenvalues will deflate without + // it. Here, the QR sweep is skipped if many eigenvalues have + // just been deflated or if the remaining active block is small. + if ld > 0 && (100*ld > nw*nibble || kbot-ktop+1 <= min(nmin, nwmax)) { + // ld is positive, note progress. + ndfl = 1 + continue + } + + // ns is the nominal number of simultaneous shifts. This may be + // lowered (slightly) if Dlaqr23 did not provide that many + // shifts. + ns := min(min(nsmax, nsr), max(2, kbot-ktop)) &^ 1 + + // If there have been no deflations in a multiple of kexsh + // iterations, then try exceptional shifts. Otherwise use shifts + // provided by Dlaqr23 above or from the eigenvalues of a + // trailing principal submatrix. + if ndfl%kexsh == 0 { + ks = kbot - ns + 1 + for i := kbot; i > max(ks, ktop+1); i -= 2 { + ss := math.Abs(h[i*ldh+i-1]) + math.Abs(h[(i-1)*ldh+i-2]) + aa := wilk1*ss + h[i*ldh+i] + _, _, _, _, wr[i-1], wi[i-1], wr[i], wi[i], _, _ = + impl.Dlanv2(aa, ss, wilk2*ss, aa) + } + if ks == ktop { + wr[ks+1] = h[(ks+1)*ldh+ks+1] + wi[ks+1] = 0 + wr[ks] = wr[ks+1] + wi[ks] = wi[ks+1] + } + } else { + // If we got ns/2 or fewer shifts, use Dlahqr or recur + // into Dlaqr04 on a trailing principal submatrix to get + // more. Since ns <= nsmax <=(n+6)/9, there is enough + // space below the subdiagonal to fit an ns×ns scratch + // array. + if kbot-ks+1 <= ns/2 { + ks = kbot - ns + 1 + kt = n - ns + impl.Dlacpy(blas.All, ns, ns, h[ks*ldh+ks:], ldh, h[kt*ldh:], ldh) + if ns > nmin && recur > 0 { + ks += impl.Dlaqr04(false, false, ns, 1, ns-1, h[kt*ldh:], ldh, + wr[ks:ks+ns], wi[ks:ks+ns], 0, 0, nil, 0, work, lwork, recur-1) + } else { + ks += impl.Dlahqr(false, false, ns, 0, ns-1, h[kt*ldh:], ldh, + wr[ks:ks+ns], wi[ks:ks+ns], 0, 0, nil, 1) + } + // In case of a rare QR failure use eigenvalues + // of the trailing 2×2 principal submatrix. + if ks >= kbot { + aa := h[(kbot-1)*ldh+kbot-1] + bb := h[(kbot-1)*ldh+kbot] + cc := h[kbot*ldh+kbot-1] + dd := h[kbot*ldh+kbot] + _, _, _, _, wr[kbot-1], wi[kbot-1], wr[kbot], wi[kbot], _, _ = + impl.Dlanv2(aa, bb, cc, dd) + ks = kbot - 1 + } + } + + if kbot-ks+1 > ns { + // Sorting the shifts helps a little. Bubble + // sort keeps complex conjugate pairs together. 
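The sort key |wr|+|wi| is identical for the two members of a conjugate pair, so any stable sort keeps them adjacent; the hand-rolled bubble sort in the following lines is simply one such sort with no extra dependencies. An equivalent sketch using the standard library (illustrative only; assumes the math and sort imports):

// Sort shifts by decreasing |re|+|im|; conjugate pairs share a key and
// therefore stay adjacent under a stable sort.
type shift struct{ re, im float64 }

shifts := []shift{{1, 2}, {1, -2}, {-5, 0}, {0.5, 0}}
sort.SliceStable(shifts, func(i, j int) bool {
	return math.Abs(shifts[i].re)+math.Abs(shifts[i].im) >
		math.Abs(shifts[j].re)+math.Abs(shifts[j].im)
})
// Result: {-5 0} {1 2} {1 -2} {0.5 0}, with the conjugate pair still together.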
+ sorted := false
+ for k := kbot; k > ks; k-- {
+ if sorted {
+ break
+ }
+ sorted = true
+ for i := ks; i < k; i++ {
+ if math.Abs(wr[i])+math.Abs(wi[i]) >= math.Abs(wr[i+1])+math.Abs(wi[i+1]) {
+ continue
+ }
+ sorted = false
+ wr[i], wr[i+1] = wr[i+1], wr[i]
+ wi[i], wi[i+1] = wi[i+1], wi[i]
+ }
+ }
+ }
+
+ // Shuffle shifts into pairs of real shifts and pairs of
+ // complex conjugate shifts using the fact that complex
+ // conjugate shifts are already adjacent to one another.
+ // TODO(vladimir-ch): The shuffling here could probably
+ // be removed but I'm not sure right now and it's safer
+ // to leave it.
+ for i := kbot; i > ks+1; i -= 2 {
+ if wi[i] == -wi[i-1] {
+ continue
+ }
+ wr[i], wr[i-1], wr[i-2] = wr[i-1], wr[i-2], wr[i]
+ wi[i], wi[i-1], wi[i-2] = wi[i-1], wi[i-2], wi[i]
+ }
+ }
+
+ // If there are only two shifts and both are real, then use only one.
+ if kbot-ks+1 == 2 && wi[kbot] == 0 {
+ if math.Abs(wr[kbot]-h[kbot*ldh+kbot]) < math.Abs(wr[kbot-1]-h[kbot*ldh+kbot]) {
+ wr[kbot-1] = wr[kbot]
+ } else {
+ wr[kbot] = wr[kbot-1]
+ }
+ }
+
+ // Use up to ns of the smallest magnitude shifts. If there
+ // aren't ns shifts available, then use them all, possibly
+ // dropping one to make the number of shifts even.
+ ns = min(ns, kbot-ks+1) &^ 1
+ ks = kbot - ns + 1
+
+ // Split workspace under the subdiagonal into:
+ // - a kdu×kdu work array U in the lower left-hand-corner,
+ // - a kdu×nhv horizontal work array WH along the bottom edge
+ // (nhv must be at least kdu but more is better),
+ // - an nhv×kdu vertical work array WV along the left-hand-edge
+ // (nhv must be at least kdu but more is better).
+ kdu := 3*ns - 3
+ ku := n - kdu
+ kwh := kdu
+ kwv = kdu + 3
+ nhv = n - kwv - kdu
+ // Small-bulge multi-shift QR sweep.
+ impl.Dlaqr5(wantt, wantz, kacc22, n, ktop, kbot, ns,
+ wr[ks:ks+ns], wi[ks:ks+ns], h, ldh, iloz, ihiz, z, ldz,
+ work, 3, h[ku*ldh:], ldh, nhv, h[kwv*ldh:], ldh, nhv, h[ku*ldh+kwh:], ldh)
+
+ // Note progress (or the lack of it).
+ if ld > 0 {
+ ndfl = 1
+ } else {
+ ndfl++
+ }
+ }
+
+ work[0] = float64(lwkopt)
+ return unconverged
+}
diff --git a/vendor/gonum.org/v1/gonum/lapack/gonum/dlaqr5.go b/vendor/gonum.org/v1/gonum/lapack/gonum/dlaqr5.go
new file mode 100644
index 0000000000..43b425b8d9
--- /dev/null
+++ b/vendor/gonum.org/v1/gonum/lapack/gonum/dlaqr5.go
@@ -0,0 +1,648 @@
+// Copyright ©2016 The Gonum Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gonum
+
+import (
+ "math"
+
+ "gonum.org/v1/gonum/blas"
+ "gonum.org/v1/gonum/blas/blas64"
+)
+
+// Dlaqr5 performs a single small-bulge multi-shift QR sweep on an isolated
+// block of a Hessenberg matrix.
+//
+// wantt and wantz determine whether the quasi-triangular Schur factor and the
+// orthogonal Schur factor, respectively, will be computed.
+//
+// kacc22 specifies the computation mode of far-from-diagonal orthogonal
+// updates. Permitted values are:
+//
+// 0: Dlaqr5 will not accumulate reflections and will not use matrix-matrix
+// multiply to update far-from-diagonal matrix entries.
+// 1: Dlaqr5 will accumulate reflections and use matrix-matrix multiply to
+// update far-from-diagonal matrix entries.
+// 2: Dlaqr5 will accumulate reflections, use matrix-matrix multiply to update
+// far-from-diagonal matrix entries, and take advantage of 2×2 block
+// structure during matrix multiplies.
+//
+// For other values of kacc22 Dlaqr5 will panic.
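In practice kacc22 trades flop count against level-3 BLAS throughput: mode 0 applies each small reflection directly to the far-from-diagonal entries, while modes 1 and 2 accumulate the reflections into the small orthogonal factor U and apply it with Dgemm (mode 2 additionally exploits U's triangular blocks via Dtrmm). A hedged sketch of a caller-side choice (the heuristic is made up; Dlaqr04 instead consults Ilaenv):

// Illustrative only: prefer accumulated (blocked) updates for all but the
// smallest sweeps, where the bookkeeping outweighs the gemm benefit.
func pickKacc22(ns int) int {
	if ns <= 2 {
		return 0 // too few shifts for blocking to pay off
	}
	return 2 // accumulate and exploit the 2×2 block structure
}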
+// +// n is the order of the Hessenberg matrix H. +// +// ktop and kbot are indices of the first and last row and column of an isolated +// diagonal block upon which the QR sweep will be applied. It must hold that +// +// ktop == 0, or 0 < ktop <= n-1 and H[ktop, ktop-1] == 0, and +// kbot == n-1, or 0 <= kbot < n-1 and H[kbot+1, kbot] == 0, +// +// otherwise Dlaqr5 will panic. +// +// nshfts is the number of simultaneous shifts. It must be positive and even, +// otherwise Dlaqr5 will panic. +// +// sr and si contain the real and imaginary parts, respectively, of the shifts +// of origin that define the multi-shift QR sweep. On return both slices may be +// reordered by Dlaqr5. Their length must be equal to nshfts, otherwise Dlaqr5 +// will panic. +// +// h and ldh represent the Hessenberg matrix H of size n×n. On return +// multi-shift QR sweep with shifts sr+i*si has been applied to the isolated +// diagonal block in rows and columns ktop through kbot, inclusive. +// +// iloz and ihiz specify the rows of Z to which transformations will be applied +// if wantz is true. It must hold that 0 <= iloz <= ihiz < n, otherwise Dlaqr5 +// will panic. +// +// z and ldz represent the matrix Z of size n×n. If wantz is true, the QR sweep +// orthogonal similarity transformation is accumulated into +// z[iloz:ihiz,iloz:ihiz] from the right, otherwise z not referenced. +// +// v and ldv represent an auxiliary matrix V of size (nshfts/2)×3. Note that V +// is transposed with respect to the reference netlib implementation. +// +// u and ldu represent an auxiliary matrix of size (3*nshfts-3)×(3*nshfts-3). +// +// wh and ldwh represent an auxiliary matrix of size (3*nshfts-3)×nh. +// +// wv and ldwv represent an auxiliary matrix of size nv×(3*nshfts-3). +// +// Dlaqr5 is an internal routine. It is exported for testing purposes. +func (impl Implementation) Dlaqr5(wantt, wantz bool, kacc22 int, n, ktop, kbot, nshfts int, sr, si []float64, h []float64, ldh int, iloz, ihiz int, z []float64, ldz int, v []float64, ldv int, u []float64, ldu int, nv int, wv []float64, ldwv int, nh int, wh []float64, ldwh int) { + switch { + case kacc22 != 0 && kacc22 != 1 && kacc22 != 2: + panic(badKacc22) + case n < 0: + panic(nLT0) + case ktop < 0 || n <= ktop: + panic(badKtop) + case kbot < 0 || n <= kbot: + panic(badKbot) + + case nshfts < 0: + panic(nshftsLT0) + case nshfts&0x1 != 0: + panic(nshftsOdd) + case len(sr) != nshfts: + panic(badLenSr) + case len(si) != nshfts: + panic(badLenSi) + + case ldh < max(1, n): + panic(badLdH) + case len(h) < (n-1)*ldh+n: + panic(shortH) + + case wantz && ihiz >= n: + panic(badIhiz) + case wantz && iloz < 0 || ihiz < iloz: + panic(badIloz) + case ldz < 1, wantz && ldz < n: + panic(badLdZ) + case wantz && len(z) < (n-1)*ldz+n: + panic(shortZ) + + case ldv < 3: + // V is transposed w.r.t. reference lapack. + panic(badLdV) + case len(v) < (nshfts/2-1)*ldv+3: + panic(shortV) + + case ldu < max(1, 3*nshfts-3): + panic(badLdU) + case len(u) < (3*nshfts-3-1)*ldu+3*nshfts-3: + panic(shortU) + + case nv < 0: + panic(nvLT0) + case ldwv < max(1, 3*nshfts-3): + panic(badLdWV) + case len(wv) < (nv-1)*ldwv+3*nshfts-3: + panic(shortWV) + + case nh < 0: + panic(nhLT0) + case ldwh < max(1, nh): + panic(badLdWH) + case len(wh) < (3*nshfts-3-1)*ldwh+nh: + panic(shortWH) + + case ktop > 0 && h[ktop*ldh+ktop-1] != 0: + panic(notIsolated) + case kbot < n-1 && h[(kbot+1)*ldh+kbot] != 0: + panic(notIsolated) + } + + // If there are no shifts, then there is nothing to do. 
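All of the scratch-space lengths checked above derive from the shift count. A small helper that makes the implied extents explicit (illustrative only; sweepScratch is a made-up name):

// sweepScratch returns the accumulator order kdu and the minimum slice
// lengths for u, wh and wv implied by the parameter checks above.
func sweepScratch(nshfts, ldu, nh, ldwh, nv, ldwv int) (kdu, lenU, lenWH, lenWV int) {
	kdu = 3*nshfts - 3
	lenU = (kdu-1)*ldu + kdu
	lenWH = (kdu-1)*ldwh + nh
	lenWV = (nv-1)*ldwv + kdu
	return kdu, lenU, lenWH, lenWV
}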
+ if nshfts < 2 { + return + } + // If the active block is empty or 1×1, then there is nothing to do. + if ktop >= kbot { + return + } + + // Shuffle shifts into pairs of real shifts and pairs of complex + // conjugate shifts assuming complex conjugate shifts are already + // adjacent to one another. + for i := 0; i < nshfts-2; i += 2 { + if si[i] == -si[i+1] { + continue + } + sr[i], sr[i+1], sr[i+2] = sr[i+1], sr[i+2], sr[i] + si[i], si[i+1], si[i+2] = si[i+1], si[i+2], si[i] + } + + // Note: lapack says that nshfts must be even but allows it to be odd + // anyway. We panic above if nshfts is not even, so reducing it by one + // is unnecessary. The only caller Dlaqr04 uses only even nshfts. + // + // The original comment and code from lapack-3.6.0/SRC/dlaqr5.f:341: + // * ==== NSHFTS is supposed to be even, but if it is odd, + // * . then simply reduce it by one. The shuffle above + // * . ensures that the dropped shift is real and that + // * . the remaining shifts are paired. ==== + // * + // NS = NSHFTS - MOD( NSHFTS, 2 ) + ns := nshfts + + safmin := dlamchS + ulp := dlamchP + smlnum := safmin * float64(n) / ulp + + // Use accumulated reflections to update far-from-diagonal entries? + accum := kacc22 == 1 || kacc22 == 2 + // If so, exploit the 2×2 block structure? + blk22 := ns > 2 && kacc22 == 2 + + // Clear trash. + if ktop+2 <= kbot { + h[(ktop+2)*ldh+ktop] = 0 + } + + // nbmps = number of 2-shift bulges in the chain. + nbmps := ns / 2 + + // kdu = width of slab. + kdu := 6*nbmps - 3 + + // Create and chase chains of nbmps bulges. + for incol := 3*(1-nbmps) + ktop - 1; incol <= kbot-2; incol += 3*nbmps - 2 { + ndcol := incol + kdu + if accum { + impl.Dlaset(blas.All, kdu, kdu, 0, 1, u, ldu) + } + + // Near-the-diagonal bulge chase. The following loop performs + // the near-the-diagonal part of a small bulge multi-shift QR + // sweep. Each 6*nbmps-2 column diagonal chunk extends from + // column incol to column ndcol (including both column incol and + // column ndcol). The following loop chases a 3*nbmps column + // long chain of nbmps bulges 3*nbmps-2 columns to the right. + // (incol may be less than ktop and ndcol may be greater than + // kbot indicating phantom columns from which to chase bulges + // before they are actually introduced or to which to chase + // bulges beyond column kbot.) + for krcol := incol; krcol <= min(incol+3*nbmps-3, kbot-2); krcol++ { + // Bulges number mtop to mbot are active double implicit + // shift bulges. There may or may not also be small 2×2 + // bulge, if there is room. The inactive bulges (if any) + // must wait until the active bulges have moved down the + // diagonal to make room. The phantom matrix paradigm + // described above helps keep track. + + mtop := max(0, ((ktop-1)-krcol+2)/3) + mbot := min(nbmps, (kbot-krcol)/3) - 1 + m22 := mbot + 1 + bmp22 := (mbot < nbmps-1) && (krcol+3*m22 == kbot-2) + + // Generate reflections to chase the chain right one + // column. (The minimum value of k is ktop-1.) + for m := mtop; m <= mbot; m++ { + k := krcol + 3*m + if k == ktop-1 { + impl.Dlaqr1(3, h[ktop*ldh+ktop:], ldh, + sr[2*m], si[2*m], sr[2*m+1], si[2*m+1], + v[m*ldv:m*ldv+3]) + alpha := v[m*ldv] + _, v[m*ldv] = impl.Dlarfg(3, alpha, v[m*ldv+1:m*ldv+3], 1) + continue + } + beta := h[(k+1)*ldh+k] + v[m*ldv+1] = h[(k+2)*ldh+k] + v[m*ldv+2] = h[(k+3)*ldh+k] + beta, v[m*ldv] = impl.Dlarfg(3, beta, v[m*ldv+1:m*ldv+3], 1) + + // A bulge may collapse because of vigilant deflation or + // destructive underflow. 
+ // In the underflow case, try the
+ // two-small-subdiagonals trick to try to reinflate the
+ // bulge.
+ if h[(k+3)*ldh+k] != 0 || h[(k+3)*ldh+k+1] != 0 || h[(k+3)*ldh+k+2] == 0 {
+ // Typical case: not collapsed (yet).
+ h[(k+1)*ldh+k] = beta
+ h[(k+2)*ldh+k] = 0
+ h[(k+3)*ldh+k] = 0
+ continue
+ }
+
+ // Atypical case: collapsed. Attempt to reintroduce
+ // ignoring H[k+1,k] and H[k+2,k]. If the fill
+ // resulting from the new reflector is too large,
+ // then abandon it. Otherwise, use the new one.
+ var vt [3]float64
+ impl.Dlaqr1(3, h[(k+1)*ldh+k+1:], ldh, sr[2*m],
+ si[2*m], sr[2*m+1], si[2*m+1], vt[:])
+ alpha := vt[0]
+ _, vt[0] = impl.Dlarfg(3, alpha, vt[1:3], 1)
+ refsum := vt[0] * (h[(k+1)*ldh+k] + vt[1]*h[(k+2)*ldh+k])
+
+ dsum := math.Abs(h[k*ldh+k]) + math.Abs(h[(k+1)*ldh+k+1]) + math.Abs(h[(k+2)*ldh+k+2])
+ if math.Abs(h[(k+2)*ldh+k]-refsum*vt[1])+math.Abs(refsum*vt[2]) > ulp*dsum {
+ // Starting a new bulge here would create
+ // non-negligible fill. Use the old one with
+ // trepidation.
+ h[(k+1)*ldh+k] = beta
+ h[(k+2)*ldh+k] = 0
+ h[(k+3)*ldh+k] = 0
+ continue
+ } else {
+ // Starting a new bulge here would create
+ // only negligible fill. Replace the old
+ // reflector with the new one.
+ h[(k+1)*ldh+k] -= refsum
+ h[(k+2)*ldh+k] = 0
+ h[(k+3)*ldh+k] = 0
+ v[m*ldv] = vt[0]
+ v[m*ldv+1] = vt[1]
+ v[m*ldv+2] = vt[2]
+ }
+ }
+
+ // Generate a 2×2 reflection, if needed.
+ if bmp22 {
+ k := krcol + 3*m22
+ if k == ktop-1 {
+ impl.Dlaqr1(2, h[(k+1)*ldh+k+1:], ldh,
+ sr[2*m22], si[2*m22], sr[2*m22+1], si[2*m22+1],
+ v[m22*ldv:m22*ldv+2])
+ beta := v[m22*ldv]
+ _, v[m22*ldv] = impl.Dlarfg(2, beta, v[m22*ldv+1:m22*ldv+2], 1)
+ } else {
+ beta := h[(k+1)*ldh+k]
+ v[m22*ldv+1] = h[(k+2)*ldh+k]
+ beta, v[m22*ldv] = impl.Dlarfg(2, beta, v[m22*ldv+1:m22*ldv+2], 1)
+ h[(k+1)*ldh+k] = beta
+ h[(k+2)*ldh+k] = 0
+ }
+ }
+
+ // Multiply H by reflections from the left.
+ var jbot int
+ switch {
+ case accum:
+ jbot = min(ndcol, kbot)
+ case wantt:
+ jbot = n - 1
+ default:
+ jbot = kbot
+ }
+ for j := max(ktop, krcol); j <= jbot; j++ {
+ mend := min(mbot+1, (j-krcol+2)/3) - 1
+ for m := mtop; m <= mend; m++ {
+ k := krcol + 3*m
+ refsum := v[m*ldv] * (h[(k+1)*ldh+j] +
+ v[m*ldv+1]*h[(k+2)*ldh+j] + v[m*ldv+2]*h[(k+3)*ldh+j])
+ h[(k+1)*ldh+j] -= refsum
+ h[(k+2)*ldh+j] -= refsum * v[m*ldv+1]
+ h[(k+3)*ldh+j] -= refsum * v[m*ldv+2]
+ }
+ }
+ if bmp22 {
+ k := krcol + 3*m22
+ for j := max(k+1, ktop); j <= jbot; j++ {
+ refsum := v[m22*ldv] * (h[(k+1)*ldh+j] + v[m22*ldv+1]*h[(k+2)*ldh+j])
+ h[(k+1)*ldh+j] -= refsum
+ h[(k+2)*ldh+j] -= refsum * v[m22*ldv+1]
+ }
+ }
+
+ // Multiply H by reflections from the right. Delay filling in the last row
+ // until the vigilant deflation check is complete.
+ var jtop int
+ switch {
+ case accum:
+ jtop = max(ktop, incol)
+ case wantt:
+ jtop = 0
+ default:
+ jtop = ktop
+ }
+ for m := mtop; m <= mbot; m++ {
+ if v[m*ldv] == 0 {
+ continue
+ }
+ k := krcol + 3*m
+ for j := jtop; j <= min(kbot, k+3); j++ {
+ refsum := v[m*ldv] * (h[j*ldh+k+1] +
+ v[m*ldv+1]*h[j*ldh+k+2] + v[m*ldv+2]*h[j*ldh+k+3])
+ h[j*ldh+k+1] -= refsum
+ h[j*ldh+k+2] -= refsum * v[m*ldv+1]
+ h[j*ldh+k+3] -= refsum * v[m*ldv+2]
+ }
+ if accum {
+ // Accumulate U. (If necessary, update Z later with an
+ // efficient matrix-matrix multiply.)
+ kms := k - incol + for j := max(0, ktop-incol-1); j < kdu; j++ { + refsum := v[m*ldv] * (u[j*ldu+kms] + + v[m*ldv+1]*u[j*ldu+kms+1] + v[m*ldv+2]*u[j*ldu+kms+2]) + u[j*ldu+kms] -= refsum + u[j*ldu+kms+1] -= refsum * v[m*ldv+1] + u[j*ldu+kms+2] -= refsum * v[m*ldv+2] + } + } else if wantz { + // U is not accumulated, so update Z now by multiplying by + // reflections from the right. + for j := iloz; j <= ihiz; j++ { + refsum := v[m*ldv] * (z[j*ldz+k+1] + + v[m*ldv+1]*z[j*ldz+k+2] + v[m*ldv+2]*z[j*ldz+k+3]) + z[j*ldz+k+1] -= refsum + z[j*ldz+k+2] -= refsum * v[m*ldv+1] + z[j*ldz+k+3] -= refsum * v[m*ldv+2] + } + } + } + + // Special case: 2×2 reflection (if needed). + if bmp22 && v[m22*ldv] != 0 { + k := krcol + 3*m22 + for j := jtop; j <= min(kbot, k+3); j++ { + refsum := v[m22*ldv] * (h[j*ldh+k+1] + v[m22*ldv+1]*h[j*ldh+k+2]) + h[j*ldh+k+1] -= refsum + h[j*ldh+k+2] -= refsum * v[m22*ldv+1] + } + if accum { + kms := k - incol + for j := max(0, ktop-incol-1); j < kdu; j++ { + refsum := v[m22*ldv] * (u[j*ldu+kms] + v[m22*ldv+1]*u[j*ldu+kms+1]) + u[j*ldu+kms] -= refsum + u[j*ldu+kms+1] -= refsum * v[m22*ldv+1] + } + } else if wantz { + for j := iloz; j <= ihiz; j++ { + refsum := v[m22*ldv] * (z[j*ldz+k+1] + v[m22*ldv+1]*z[j*ldz+k+2]) + z[j*ldz+k+1] -= refsum + z[j*ldz+k+2] -= refsum * v[m22*ldv+1] + } + } + } + + // Vigilant deflation check. + mstart := mtop + if krcol+3*mstart < ktop { + mstart++ + } + mend := mbot + if bmp22 { + mend++ + } + if krcol == kbot-2 { + mend++ + } + for m := mstart; m <= mend; m++ { + k := min(kbot-1, krcol+3*m) + + // The following convergence test requires that the tradition + // small-compared-to-nearby-diagonals criterion and the Ahues & + // Tisseur (LAWN 122, 1997) criteria both be satisfied. The latter + // improves accuracy in some examples. Falling back on an alternate + // convergence criterion when tst1 or tst2 is zero (as done here) is + // traditional but probably unnecessary. + + if h[(k+1)*ldh+k] == 0 { + continue + } + tst1 := math.Abs(h[k*ldh+k]) + math.Abs(h[(k+1)*ldh+k+1]) + if tst1 == 0 { + if k >= ktop+1 { + tst1 += math.Abs(h[k*ldh+k-1]) + } + if k >= ktop+2 { + tst1 += math.Abs(h[k*ldh+k-2]) + } + if k >= ktop+3 { + tst1 += math.Abs(h[k*ldh+k-3]) + } + if k <= kbot-2 { + tst1 += math.Abs(h[(k+2)*ldh+k+1]) + } + if k <= kbot-3 { + tst1 += math.Abs(h[(k+3)*ldh+k+1]) + } + if k <= kbot-4 { + tst1 += math.Abs(h[(k+4)*ldh+k+1]) + } + } + if math.Abs(h[(k+1)*ldh+k]) <= math.Max(smlnum, ulp*tst1) { + h12 := math.Max(math.Abs(h[(k+1)*ldh+k]), math.Abs(h[k*ldh+k+1])) + h21 := math.Min(math.Abs(h[(k+1)*ldh+k]), math.Abs(h[k*ldh+k+1])) + h11 := math.Max(math.Abs(h[(k+1)*ldh+k+1]), math.Abs(h[k*ldh+k]-h[(k+1)*ldh+k+1])) + h22 := math.Min(math.Abs(h[(k+1)*ldh+k+1]), math.Abs(h[k*ldh+k]-h[(k+1)*ldh+k+1])) + scl := h11 + h12 + tst2 := h22 * (h11 / scl) + if tst2 == 0 || h21*(h12/scl) <= math.Max(smlnum, ulp*tst2) { + h[(k+1)*ldh+k] = 0 + } + } + } + + // Fill in the last row of each bulge. + mend = min(nbmps, (kbot-krcol-1)/3) - 1 + for m := mtop; m <= mend; m++ { + k := krcol + 3*m + refsum := v[m*ldv] * v[m*ldv+2] * h[(k+4)*ldh+k+3] + h[(k+4)*ldh+k+1] = -refsum + h[(k+4)*ldh+k+2] = -refsum * v[m*ldv+1] + h[(k+4)*ldh+k+3] -= refsum * v[m*ldv+2] + } + } + + // Use U (if accumulated) to update far-from-diagonal entries in H. + // If required, use U to update Z as well. 
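When accum is set, the far-from-diagonal work is deferred: the reflections are folded into the small orthogonal factor U here, and U is later applied to whole panels of H and Z with matrix-matrix multiplies. The horizontal multiply used later has this shape (a hedged sketch; leftMulT is a made-up name, and the real code reuses workspace carved out of h):

// leftMulT overwrites the m×n row-major block a (stride lda) with qᵀ*a,
// using scratch w of the same shape (ldw >= n); q is m×m. This mirrors the
// Dgemm-plus-copy pattern of the horizontal update below with u in the role
// of q.
func leftMulT(m, n int, q []float64, ldq int, a []float64, lda int, w []float64, ldw int) {
	bi := blas64.Implementation()
	bi.Dgemm(blas.Trans, blas.NoTrans, m, n, m, 1, q, ldq, a, lda, 0, w, ldw)
	for i := 0; i < m; i++ {
		copy(a[i*lda:i*lda+n], w[i*ldw:i*ldw+n])
	}
}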
+ if !accum { + continue + } + var jtop, jbot int + if wantt { + jtop = 0 + jbot = n - 1 + } else { + jtop = ktop + jbot = kbot + } + bi := blas64.Implementation() + if !blk22 || incol < ktop || kbot < ndcol || ns <= 2 { + // Updates not exploiting the 2×2 block structure of U. k0 and nu keep track + // of the location and size of U in the special cases of introducing bulges + // and chasing bulges off the bottom. In these special cases and in case the + // number of shifts is ns = 2, there is no 2×2 block structure to exploit. + + k0 := max(0, ktop-incol-1) + nu := kdu - max(0, ndcol-kbot) - k0 + + // Horizontal multiply. + for jcol := min(ndcol, kbot) + 1; jcol <= jbot; jcol += nh { + jlen := min(nh, jbot-jcol+1) + bi.Dgemm(blas.Trans, blas.NoTrans, nu, jlen, nu, + 1, u[k0*ldu+k0:], ldu, + h[(incol+k0+1)*ldh+jcol:], ldh, + 0, wh, ldwh) + impl.Dlacpy(blas.All, nu, jlen, wh, ldwh, h[(incol+k0+1)*ldh+jcol:], ldh) + } + + // Vertical multiply. + for jrow := jtop; jrow <= max(ktop, incol)-1; jrow += nv { + jlen := min(nv, max(ktop, incol)-jrow) + bi.Dgemm(blas.NoTrans, blas.NoTrans, jlen, nu, nu, + 1, h[jrow*ldh+incol+k0+1:], ldh, + u[k0*ldu+k0:], ldu, + 0, wv, ldwv) + impl.Dlacpy(blas.All, jlen, nu, wv, ldwv, h[jrow*ldh+incol+k0+1:], ldh) + } + + // Z multiply (also vertical). + if wantz { + for jrow := iloz; jrow <= ihiz; jrow += nv { + jlen := min(nv, ihiz-jrow+1) + bi.Dgemm(blas.NoTrans, blas.NoTrans, jlen, nu, nu, + 1, z[jrow*ldz+incol+k0+1:], ldz, + u[k0*ldu+k0:], ldu, + 0, wv, ldwv) + impl.Dlacpy(blas.All, jlen, nu, wv, ldwv, z[jrow*ldz+incol+k0+1:], ldz) + } + } + + continue + } + + // Updates exploiting U's 2×2 block structure. + + // i2, i4, j2, j4 are the last rows and columns of the blocks. + i2 := (kdu + 1) / 2 + i4 := kdu + j2 := i4 - i2 + j4 := kdu + + // kzs and knz deal with the band of zeros along the diagonal of one of the + // triangular blocks. + kzs := (j4 - j2) - (ns + 1) + knz := ns + 1 + + // Horizontal multiply. + for jcol := min(ndcol, kbot) + 1; jcol <= jbot; jcol += nh { + jlen := min(nh, jbot-jcol+1) + + // Copy bottom of H to top+kzs of scratch (the first kzs + // rows get multiplied by zero). + impl.Dlacpy(blas.All, knz, jlen, h[(incol+1+j2)*ldh+jcol:], ldh, wh[kzs*ldwh:], ldwh) + + // Multiply by U21ᵀ. + impl.Dlaset(blas.All, kzs, jlen, 0, 0, wh, ldwh) + bi.Dtrmm(blas.Left, blas.Upper, blas.Trans, blas.NonUnit, knz, jlen, + 1, u[j2*ldu+kzs:], ldu, wh[kzs*ldwh:], ldwh) + + // Multiply top of H by U11ᵀ. + bi.Dgemm(blas.Trans, blas.NoTrans, i2, jlen, j2, + 1, u, ldu, h[(incol+1)*ldh+jcol:], ldh, + 1, wh, ldwh) + + // Copy top of H to bottom of WH. + impl.Dlacpy(blas.All, j2, jlen, h[(incol+1)*ldh+jcol:], ldh, wh[i2*ldwh:], ldwh) + + // Multiply by U21ᵀ. + bi.Dtrmm(blas.Left, blas.Lower, blas.Trans, blas.NonUnit, j2, jlen, + 1, u[i2:], ldu, wh[i2*ldwh:], ldwh) + + // Multiply by U22. + bi.Dgemm(blas.Trans, blas.NoTrans, i4-i2, jlen, j4-j2, + 1, u[j2*ldu+i2:], ldu, h[(incol+1+j2)*ldh+jcol:], ldh, + 1, wh[i2*ldwh:], ldwh) + + // Copy it back. + impl.Dlacpy(blas.All, kdu, jlen, wh, ldwh, h[(incol+1)*ldh+jcol:], ldh) + } + + // Vertical multiply. + for jrow := jtop; jrow <= max(incol, ktop)-1; jrow += nv { + jlen := min(nv, max(incol, ktop)-jrow) + + // Copy right of H to scratch (the first kzs columns get multiplied + // by zero). + impl.Dlacpy(blas.All, jlen, knz, h[jrow*ldh+incol+1+j2:], ldh, wv[kzs:], ldwv) + + // Multiply by U21. 
+ impl.Dlaset(blas.All, jlen, kzs, 0, 0, wv, ldwv) + bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.NonUnit, jlen, knz, + 1, u[j2*ldu+kzs:], ldu, wv[kzs:], ldwv) + + // Multiply by U11. + bi.Dgemm(blas.NoTrans, blas.NoTrans, jlen, i2, j2, + 1, h[jrow*ldh+incol+1:], ldh, u, ldu, + 1, wv, ldwv) + + // Copy left of H to right of scratch. + impl.Dlacpy(blas.All, jlen, j2, h[jrow*ldh+incol+1:], ldh, wv[i2:], ldwv) + + // Multiply by U21. + bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.NonUnit, jlen, i4-i2, + 1, u[i2:], ldu, wv[i2:], ldwv) + + // Multiply by U22. + bi.Dgemm(blas.NoTrans, blas.NoTrans, jlen, i4-i2, j4-j2, + 1, h[jrow*ldh+incol+1+j2:], ldh, u[j2*ldu+i2:], ldu, + 1, wv[i2:], ldwv) + + // Copy it back. + impl.Dlacpy(blas.All, jlen, kdu, wv, ldwv, h[jrow*ldh+incol+1:], ldh) + } + + if !wantz { + continue + } + // Multiply Z (also vertical). + for jrow := iloz; jrow <= ihiz; jrow += nv { + jlen := min(nv, ihiz-jrow+1) + + // Copy right of Z to left of scratch (first kzs columns get + // multiplied by zero). + impl.Dlacpy(blas.All, jlen, knz, z[jrow*ldz+incol+1+j2:], ldz, wv[kzs:], ldwv) + + // Multiply by U12. + impl.Dlaset(blas.All, jlen, kzs, 0, 0, wv, ldwv) + bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.NonUnit, jlen, knz, + 1, u[j2*ldu+kzs:], ldu, wv[kzs:], ldwv) + + // Multiply by U11. + bi.Dgemm(blas.NoTrans, blas.NoTrans, jlen, i2, j2, + 1, z[jrow*ldz+incol+1:], ldz, u, ldu, + 1, wv, ldwv) + + // Copy left of Z to right of scratch. + impl.Dlacpy(blas.All, jlen, j2, z[jrow*ldz+incol+1:], ldz, wv[i2:], ldwv) + + // Multiply by U21. + bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.NonUnit, jlen, i4-i2, + 1, u[i2:], ldu, wv[i2:], ldwv) + + // Multiply by U22. + bi.Dgemm(blas.NoTrans, blas.NoTrans, jlen, i4-i2, j4-j2, + 1, z[jrow*ldz+incol+1+j2:], ldz, u[j2*ldu+i2:], ldu, + 1, wv[i2:], ldwv) + + // Copy the result back to Z. + impl.Dlacpy(blas.All, jlen, kdu, wv, ldwv, z[jrow*ldz+incol+1:], ldz) + } + } +} diff --git a/vendor/gonum.org/v1/gonum/lapack/gonum/dlassq.go b/vendor/gonum.org/v1/gonum/lapack/gonum/dlassq.go new file mode 100644 index 0000000000..3d982c3ccd --- /dev/null +++ b/vendor/gonum.org/v1/gonum/lapack/gonum/dlassq.go @@ -0,0 +1,127 @@ +// Copyright ©2015 The Gonum Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gonum + +import "math" + +// Dlassq updates a sum of squares represented in scaled form. Dlassq returns +// the values scl and smsq such that +// +// scl^2*smsq = X[0]^2 + ... + X[n-1]^2 + scale^2*sumsq +// +// The value of sumsq is assumed to be non-negative. +// +// Dlassq is an internal routine. It is exported for testing purposes. +func (impl Implementation) Dlassq(n int, x []float64, incx int, scale float64, sumsq float64) (scl, smsq float64) { + // Implementation based on Supplemental Material to: + // Edward Anderson. 2017. Algorithm 978: Safe Scaling in the Level 1 BLAS. + // ACM Trans. Math. Softw. 44, 1, Article 12 (July 2017), 28 pages. 
+ // DOI: https://doi.org/10.1145/3061665 + switch { + case n < 0: + panic(nLT0) + case incx <= 0: + panic(badIncX) + case len(x) < 1+(n-1)*incx: + panic(shortX) + } + + if math.IsNaN(scale) || math.IsNaN(sumsq) { + return scale, sumsq + } + + if sumsq == 0 { + scale = 1 + } + if scale == 0 { + scale = 1 + sumsq = 0 + } + + if n == 0 { + return scale, sumsq + } + + // Compute the sum of squares in 3 accumulators: + // - abig: sum of squares scaled down to avoid overflow + // - asml: sum of squares scaled up to avoid underflow + // - amed: sum of squares that do not require scaling + // The thresholds and multipliers are: + // - values bigger than dtbig are scaled down by dsbig + // - values smaller than dtsml are scaled up by dssml + var ( + isBig bool + asml, amed, abig float64 + ) + for i, ix := 0, 0; i < n; i++ { + ax := math.Abs(x[ix]) + switch { + case ax > dtbig: + ax *= dsbig + abig += ax * ax + isBig = true + case ax < dtsml: + if !isBig { + ax *= dssml + asml += ax * ax + } + default: + amed += ax * ax + } + ix += incx + } + // Put the existing sum of squares into one of the accumulators. + if sumsq > 0 { + ax := scale * math.Sqrt(sumsq) + switch { + case ax > dtbig: + // We assume scale >= sqrt( TINY*EPS ) / dsbig, that is, if the + // scaled sum is big then its scaling factor should not be too + // small. + v := scale * dsbig + abig += (v * v) * sumsq + case ax < dtsml: + if !isBig { + // We assume scale <= sqrt( HUGE ) / dssml, that is, if the + // scaled sum is small then its scaling factor should not be too + // big. + v := scale * dssml + asml += (v * v) * sumsq + } + default: + amed += ax * ax + } + } + // Combine abig and amed or amed and asml if more than one accumulator was + // used. + switch { + case abig > 0: + // Combine abig and amed: + if amed > 0 || math.IsNaN(amed) { + abig += (amed * dsbig) * dsbig + } + scale = 1 / dsbig + sumsq = abig + case asml > 0: + // Combine amed and asml: + if amed > 0 || math.IsNaN(amed) { + amed = math.Sqrt(amed) + asml = math.Sqrt(asml) / dssml + ymin, ymax := asml, amed + if asml > amed { + ymin, ymax = amed, asml + } + scale = 1 + sumsq = ymax * ymax * (1 + (ymin/ymax)*(ymin/ymax)) + } else { + scale = 1 / dssml + sumsq = asml + } + default: + scale = 1 + sumsq = amed + } + return scale, sumsq +} diff --git a/vendor/gonum.org/v1/gonum/lapack/gonum/dlatbs.go b/vendor/gonum.org/v1/gonum/lapack/gonum/dlatbs.go new file mode 100644 index 0000000000..19300faf86 --- /dev/null +++ b/vendor/gonum.org/v1/gonum/lapack/gonum/dlatbs.go @@ -0,0 +1,454 @@ +// Copyright ©2019 The Gonum Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gonum + +import ( + "math" + + "gonum.org/v1/gonum/blas" + "gonum.org/v1/gonum/blas/blas64" +) + +// Dlatbs solves a triangular banded system of equations +// +// A * x = s*b if trans == blas.NoTrans +// Aᵀ * x = s*b if trans == blas.Trans or blas.ConjTrans +// +// where A is an upper or lower triangular band matrix, x and b are n-element +// vectors, and s is a scaling factor chosen so that the components of x will be +// less than the overflow threshold. +// +// On entry, x contains the right-hand side b of the triangular system. +// On return, x is overwritten by the solution vector x. +// +// normin specifies whether the cnorm parameter contains the column norms of A on +// entry. If it is true, cnorm[j] contains the norm of the off-diagonal part of +// the j-th column of A. 
+// If it is false, the norms will be computed and stored
+// in cnorm.
+//
+// Dlatbs returns the scaling factor s for the triangular system. If the matrix
+// A is singular (A[j,j]==0 for some j), then scale is set to 0 and a
+// non-trivial solution to A*x = 0 is returned.
+//
+// Dlatbs is an internal routine. It is exported for testing purposes.
+func (Implementation) Dlatbs(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, normin bool, n, kd int, ab []float64, ldab int, x, cnorm []float64) (scale float64) {
+ noTran := trans == blas.NoTrans
+ switch {
+ case uplo != blas.Upper && uplo != blas.Lower:
+ panic(badUplo)
+ case !noTran && trans != blas.Trans && trans != blas.ConjTrans:
+ panic(badTrans)
+ case diag != blas.NonUnit && diag != blas.Unit:
+ panic(badDiag)
+ case n < 0:
+ panic(nLT0)
+ case kd < 0:
+ panic(kdLT0)
+ case ldab < kd+1:
+ panic(badLdA)
+ }
+
+ // Quick return if possible.
+ if n == 0 {
+ return 0
+ }
+
+ switch {
+ case len(ab) < (n-1)*ldab+kd+1:
+ panic(shortAB)
+ case len(x) < n:
+ panic(shortX)
+ case len(cnorm) < n:
+ panic(shortCNorm)
+ }
+
+ // Parameters to control overflow.
+ smlnum := dlamchS / dlamchP
+ bignum := 1 / smlnum
+
+ bi := blas64.Implementation()
+ kld := max(1, ldab-1)
+ if !normin {
+ // Compute the 1-norm of each column, not including the diagonal.
+ if uplo == blas.Upper {
+ for j := 0; j < n; j++ {
+ jlen := min(j, kd)
+ if jlen > 0 {
+ cnorm[j] = bi.Dasum(jlen, ab[(j-jlen)*ldab+jlen:], kld)
+ } else {
+ cnorm[j] = 0
+ }
+ }
+ } else {
+ for j := 0; j < n; j++ {
+ jlen := min(n-j-1, kd)
+ if jlen > 0 {
+ cnorm[j] = bi.Dasum(jlen, ab[(j+1)*ldab+kd-1:], kld)
+ } else {
+ cnorm[j] = 0
+ }
+ }
+ }
+ }
+
+ // Set up indices and increments for loops below.
+ var (
+ jFirst, jLast, jInc int
+ maind int
+ )
+ if noTran {
+ if uplo == blas.Upper {
+ jFirst = n - 1
+ jLast = -1
+ jInc = -1
+ maind = 0
+ } else {
+ jFirst = 0
+ jLast = n
+ jInc = 1
+ maind = kd
+ }
+ } else {
+ if uplo == blas.Upper {
+ jFirst = 0
+ jLast = n
+ jInc = 1
+ maind = 0
+ } else {
+ jFirst = n - 1
+ jLast = -1
+ jInc = -1
+ maind = kd
+ }
+ }
+
+ // Scale the column norms by tscal if the maximum element in cnorm is
+ // greater than bignum.
+ tmax := cnorm[bi.Idamax(n, cnorm, 1)]
+ tscal := 1.0
+ if tmax > bignum {
+ tscal = 1 / (smlnum * tmax)
+ bi.Dscal(n, tscal, cnorm, 1)
+ }
+
+ // Compute a bound on the computed solution vector to see if the Level 2
+ // BLAS routine Dtbsv can be used.
+
+ xMax := math.Abs(x[bi.Idamax(n, x, 1)])
+ xBnd := xMax
+ grow := 0.0
+ // Compute the growth only if the maximum element in cnorm is NOT greater
+ // than bignum.
+ if tscal != 1 {
+ goto skipComputeGrow
+ }
+ if noTran {
+ // Compute the growth in A * x = b.
+ if diag == blas.NonUnit {
+ // A is non-unit triangular.
+ //
+ // Compute grow = 1/G_j and xBnd = 1/M_j.
+ // Initially, G_0 = max{x(i), i=1,...,n}.
+ grow = 1 / math.Max(xBnd, smlnum)
+ xBnd = grow
+ for j := jFirst; j != jLast; j += jInc {
+ if grow <= smlnum {
+ // Exit the loop because the growth factor is too small.
+ goto skipComputeGrow
+ }
+ // M_j = G_{j-1} / abs(A[j,j])
+ tjj := math.Abs(ab[j*ldab+maind])
+ xBnd = math.Min(xBnd, math.Min(1, tjj)*grow)
+ if tjj+cnorm[j] >= smlnum {
+ // G_j = G_{j-1}*( 1 + cnorm[j] / abs(A[j,j]) )
+ grow *= tjj / (tjj + cnorm[j])
+ } else {
+ // G_j could overflow, set grow to 0.
+ grow = 0
+ }
+ }
+ grow = xBnd
+ } else {
+ // A is unit triangular.
+ //
+ // Compute grow = 1/G_j, where G_0 = max{x(i), i=1,...,n}.
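The reciprocal bookkeeping is the point of this loop: instead of tracking the potentially overflowing growth G_j = G_{j-1}*(1 + cnorm[j]), the code tracks grow = 1/G_j, which only shrinks and can be compared safely against smlnum. A simplified standalone illustration for the unit-triangular case (canUseFastSolve is a made-up name, and the real test additionally involves tscal):

// canUseFastSolve reports whether the growth bound for a unit triangular
// solve stays representable, mirroring the loop that follows.
func canUseFastSolve(cnorm []float64, xBnd, smlnum float64) bool {
	grow := math.Min(1, 1/math.Max(xBnd, smlnum))
	for _, c := range cnorm {
		if grow <= smlnum {
			return false
		}
		grow /= 1 + c
	}
	return grow > smlnum
}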
+ grow = math.Min(1, 1/math.Max(xBnd, smlnum)) + for j := jFirst; j != jLast; j += jInc { + if grow <= smlnum { + // Exit the loop because the growth factor is too small. + goto skipComputeGrow + } + // G_j = G_{j-1}*( 1 + cnorm[j] ) + grow /= 1 + cnorm[j] + } + } + } else { + // Compute the growth in Aᵀ * x = b. + if diag == blas.NonUnit { + // A is non-unit triangular. + // + // Compute grow = 1/G_j and xBnd = 1/M_j. + // Initially, G_0 = max{x(i), i=1,...,n}. + grow = 1 / math.Max(xBnd, smlnum) + xBnd = grow + for j := jFirst; j != jLast; j += jInc { + if grow <= smlnum { + // Exit the loop because the growth factor is too small. + goto skipComputeGrow + } + // G_j = max( G_{j-1}, M_{j-1}*( 1 + cnorm[j] ) ) + xj := 1 + cnorm[j] + grow = math.Min(grow, xBnd/xj) + // M_j = M_{j-1}*( 1 + cnorm[j] ) / abs(A[j,j]) + tjj := math.Abs(ab[j*ldab+maind]) + if xj > tjj { + xBnd *= tjj / xj + } + } + grow = math.Min(grow, xBnd) + } else { + // A is unit triangular. + // + // Compute grow = 1/G_j, where G_0 = max{x(i), i=1,...,n}. + grow = math.Min(1, 1/math.Max(xBnd, smlnum)) + for j := jFirst; j != jLast; j += jInc { + if grow <= smlnum { + // Exit the loop because the growth factor is too small. + goto skipComputeGrow + } + // G_j = G_{j-1}*( 1 + cnorm[j] ) + grow /= 1 + cnorm[j] + } + } + } +skipComputeGrow: + + if grow*tscal > smlnum { + // The reciprocal of the bound on elements of X is not too small, use + // the Level 2 BLAS solve. + bi.Dtbsv(uplo, trans, diag, n, kd, ab, ldab, x, 1) + // Scale the column norms by 1/tscal for return. + if tscal != 1 { + bi.Dscal(n, 1/tscal, cnorm, 1) + } + return 1 + } + + // Use a Level 1 BLAS solve, scaling intermediate results. + + scale = 1 + if xMax > bignum { + // Scale x so that its components are less than or equal to bignum in + // absolute value. + scale = bignum / xMax + bi.Dscal(n, scale, x, 1) + xMax = bignum + } + + if noTran { + // Solve A * x = b. + for j := jFirst; j != jLast; j += jInc { + // Compute x[j] = b[j] / A[j,j], scaling x if necessary. + xj := math.Abs(x[j]) + tjjs := tscal + if diag == blas.NonUnit { + tjjs *= ab[j*ldab+maind] + } + tjj := math.Abs(tjjs) + switch { + case tjj > smlnum: + // smlnum < abs(A[j,j]) + if tjj < 1 && xj > tjj*bignum { + // Scale x by 1/b[j]. + rec := 1 / xj + bi.Dscal(n, rec, x, 1) + scale *= rec + xMax *= rec + } + x[j] /= tjjs + xj = math.Abs(x[j]) + case tjj > 0: + // 0 < abs(A[j,j]) <= smlnum + if xj > tjj*bignum { + // Scale x by (1/abs(x[j]))*abs(A[j,j])*bignum to avoid + // overflow when dividing by A[j,j]. + rec := tjj * bignum / xj + if cnorm[j] > 1 { + // Scale by 1/cnorm[j] to avoid overflow when + // multiplying x[j] times column j. + rec /= cnorm[j] + } + bi.Dscal(n, rec, x, 1) + scale *= rec + xMax *= rec + } + x[j] /= tjjs + xj = math.Abs(x[j]) + default: + // A[j,j] == 0: Set x[0:n] = 0, x[j] = 1, and scale = 0, and + // compute a solution to A*x = 0. + for i := range x[:n] { + x[i] = 0 + } + x[j] = 1 + xj = 1 + scale = 0 + xMax = 0 + } + + // Scale x if necessary to avoid overflow when adding a multiple of + // column j of A. + switch { + case xj > 1: + rec := 1 / xj + if cnorm[j] > (bignum-xMax)*rec { + // Scale x by 1/(2*abs(x[j])). + rec *= 0.5 + bi.Dscal(n, rec, x, 1) + scale *= rec + } + case xj*cnorm[j] > bignum-xMax: + // Scale x by 1/2. 
+ bi.Dscal(n, 0.5, x, 1) + scale *= 0.5 + } + + if uplo == blas.Upper { + if j > 0 { + // Compute the update + // x[max(0,j-kd):j] := x[max(0,j-kd):j] - x[j] * A[max(0,j-kd):j,j] + jlen := min(j, kd) + if jlen > 0 { + bi.Daxpy(jlen, -x[j]*tscal, ab[(j-jlen)*ldab+jlen:], kld, x[j-jlen:], 1) + } + i := bi.Idamax(j, x, 1) + xMax = math.Abs(x[i]) + } + } else if j < n-1 { + // Compute the update + // x[j+1:min(j+kd,n)] := x[j+1:min(j+kd,n)] - x[j] * A[j+1:min(j+kd,n),j] + jlen := min(kd, n-j-1) + if jlen > 0 { + bi.Daxpy(jlen, -x[j]*tscal, ab[(j+1)*ldab+kd-1:], kld, x[j+1:], 1) + } + i := j + 1 + bi.Idamax(n-j-1, x[j+1:], 1) + xMax = math.Abs(x[i]) + } + } + } else { + // Solve Aᵀ * x = b. + for j := jFirst; j != jLast; j += jInc { + // Compute x[j] = b[j] - sum A[k,j]*x[k]. + // k!=j + xj := math.Abs(x[j]) + tjjs := tscal + if diag == blas.NonUnit { + tjjs *= ab[j*ldab+maind] + } + tjj := math.Abs(tjjs) + rec := 1 / math.Max(1, xMax) + uscal := tscal + if cnorm[j] > (bignum-xj)*rec { + // If x[j] could overflow, scale x by 1/(2*xMax). + rec *= 0.5 + if tjj > 1 { + // Divide by A[j,j] when scaling x if A[j,j] > 1. + rec = math.Min(1, rec*tjj) + uscal /= tjjs + } + if rec < 1 { + bi.Dscal(n, rec, x, 1) + scale *= rec + xMax *= rec + } + } + + var sumj float64 + if uscal == 1 { + // If the scaling needed for A in the dot product is 1, call + // Ddot to perform the dot product... + if uplo == blas.Upper { + jlen := min(j, kd) + if jlen > 0 { + sumj = bi.Ddot(jlen, ab[(j-jlen)*ldab+jlen:], kld, x[j-jlen:], 1) + } + } else { + jlen := min(n-j-1, kd) + if jlen > 0 { + sumj = bi.Ddot(jlen, ab[(j+1)*ldab+kd-1:], kld, x[j+1:], 1) + } + } + } else { + // ...otherwise, use in-line code for the dot product. + if uplo == blas.Upper { + jlen := min(j, kd) + for i := 0; i < jlen; i++ { + sumj += (ab[(j-jlen+i)*ldab+jlen-i] * uscal) * x[j-jlen+i] + } + } else { + jlen := min(n-j-1, kd) + for i := 0; i < jlen; i++ { + sumj += (ab[(j+1+i)*ldab+kd-1-i] * uscal) * x[j+i+1] + } + } + } + + if uscal == tscal { + // Compute x[j] := ( x[j] - sumj ) / A[j,j] + // if 1/A[j,j] was not used to scale the dot product. + x[j] -= sumj + xj = math.Abs(x[j]) + // Compute x[j] = x[j] / A[j,j], scaling if necessary. + // Note: the reference implementation skips this step for blas.Unit matrices + // when tscal is equal to 1 but it complicates the logic and only saves + // the comparison and division in the first switch-case. Not skipping it + // is also consistent with the NoTrans case above. + switch { + case tjj > smlnum: + // smlnum < abs(A[j,j]): + if tjj < 1 && xj > tjj*bignum { + // Scale x by 1/abs(x[j]). + rec := 1 / xj + bi.Dscal(n, rec, x, 1) + scale *= rec + xMax *= rec + } + x[j] /= tjjs + case tjj > 0: + // 0 < abs(A[j,j]) <= smlnum: + if xj > tjj*bignum { + // Scale x by (1/abs(x[j]))*abs(A[j,j])*bignum. + rec := (tjj * bignum) / xj + bi.Dscal(n, rec, x, 1) + scale *= rec + xMax *= rec + } + x[j] /= tjjs + default: + // A[j,j] == 0: Set x[0:n] = 0, x[j] = 1, and scale = 0, and + // compute a solution Aᵀ * x = 0. + for i := range x[:n] { + x[i] = 0 + } + x[j] = 1 + scale = 0 + xMax = 0 + } + } else { + // Compute x[j] := x[j] / A[j,j] - sumj + // if the dot product has already been divided by 1/A[j,j]. + x[j] = x[j]/tjjs - sumj + } + xMax = math.Max(xMax, math.Abs(x[j])) + } + scale /= tscal + } + + // Scale the column norms by 1/tscal for return. 
+ if tscal != 1 { + bi.Dscal(n, 1/tscal, cnorm, 1) + } + return scale +} diff --git a/vendor/gonum.org/v1/gonum/lapack/gonum/dlatrs.go b/vendor/gonum.org/v1/gonum/lapack/gonum/dlatrs.go new file mode 100644 index 0000000000..37ac2fe70a --- /dev/null +++ b/vendor/gonum.org/v1/gonum/lapack/gonum/dlatrs.go @@ -0,0 +1,361 @@ +// Copyright ©2015 The Gonum Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gonum + +import ( + "math" + + "gonum.org/v1/gonum/blas" + "gonum.org/v1/gonum/blas/blas64" +) + +// Dlatrs solves a triangular system of equations scaled to prevent overflow. It +// solves +// +// A * x = scale * b if trans == blas.NoTrans +// Aᵀ * x = scale * b if trans == blas.Trans +// +// where the scale s is set for numeric stability. +// +// A is an n×n triangular matrix. On entry, the slice x contains the values of +// b, and on exit it contains the solution vector x. +// +// If normin == true, cnorm is an input and cnorm[j] contains the norm of the off-diagonal +// part of the j^th column of A. If trans == blas.NoTrans, cnorm[j] must be greater +// than or equal to the infinity norm, and greater than or equal to the one-norm +// otherwise. If normin == false, then cnorm is treated as an output, and is set +// to contain the 1-norm of the off-diagonal part of the j^th column of A. +// +// Dlatrs is an internal routine. It is exported for testing purposes. +func (impl Implementation) Dlatrs(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, normin bool, n int, a []float64, lda int, x []float64, cnorm []float64) (scale float64) { + switch { + case uplo != blas.Upper && uplo != blas.Lower: + panic(badUplo) + case trans != blas.NoTrans && trans != blas.Trans && trans != blas.ConjTrans: + panic(badTrans) + case diag != blas.Unit && diag != blas.NonUnit: + panic(badDiag) + case n < 0: + panic(nLT0) + case lda < max(1, n): + panic(badLdA) + } + + // Quick return if possible. + if n == 0 { + return 0 + } + + switch { + case len(a) < (n-1)*lda+n: + panic(shortA) + case len(x) < n: + panic(shortX) + case len(cnorm) < n: + panic(shortCNorm) + } + + upper := uplo == blas.Upper + nonUnit := diag == blas.NonUnit + + smlnum := dlamchS / dlamchP + bignum := 1 / smlnum + scale = 1 + + bi := blas64.Implementation() + + if !normin { + if upper { + cnorm[0] = 0 + for j := 1; j < n; j++ { + cnorm[j] = bi.Dasum(j, a[j:], lda) + } + } else { + for j := 0; j < n-1; j++ { + cnorm[j] = bi.Dasum(n-j-1, a[(j+1)*lda+j:], lda) + } + cnorm[n-1] = 0 + } + } + // Scale the column norms by tscal if the maximum element in cnorm is greater than bignum. + imax := bi.Idamax(n, cnorm, 1) + tmax := cnorm[imax] + var tscal float64 + if tmax <= bignum { + tscal = 1 + } else { + tscal = 1 / (smlnum * tmax) + bi.Dscal(n, tscal, cnorm, 1) + } + + // Compute a bound on the computed solution vector to see if bi.Dtrsv can be used. + j := bi.Idamax(n, x, 1) + xmax := math.Abs(x[j]) + xbnd := xmax + var grow float64 + var jfirst, jlast, jinc int + if trans == blas.NoTrans { + if upper { + jfirst = n - 1 + jlast = -1 + jinc = -1 + } else { + jfirst = 0 + jlast = n + jinc = 1 + } + // Compute the growth in A * x = b. 
+ if tscal != 1 { + grow = 0 + goto Solve + } + if nonUnit { + grow = 1 / math.Max(xbnd, smlnum) + xbnd = grow + for j := jfirst; j != jlast; j += jinc { + if grow <= smlnum { + goto Solve + } + tjj := math.Abs(a[j*lda+j]) + xbnd = math.Min(xbnd, math.Min(1, tjj)*grow) + if tjj+cnorm[j] >= smlnum { + grow *= tjj / (tjj + cnorm[j]) + } else { + grow = 0 + } + } + grow = xbnd + } else { + grow = math.Min(1, 1/math.Max(xbnd, smlnum)) + for j := jfirst; j != jlast; j += jinc { + if grow <= smlnum { + goto Solve + } + grow *= 1 / (1 + cnorm[j]) + } + } + } else { + if upper { + jfirst = 0 + jlast = n + jinc = 1 + } else { + jfirst = n - 1 + jlast = -1 + jinc = -1 + } + if tscal != 1 { + grow = 0 + goto Solve + } + if nonUnit { + grow = 1 / (math.Max(xbnd, smlnum)) + xbnd = grow + for j := jfirst; j != jlast; j += jinc { + if grow <= smlnum { + goto Solve + } + xj := 1 + cnorm[j] + grow = math.Min(grow, xbnd/xj) + tjj := math.Abs(a[j*lda+j]) + if xj > tjj { + xbnd *= tjj / xj + } + } + grow = math.Min(grow, xbnd) + } else { + grow = math.Min(1, 1/math.Max(xbnd, smlnum)) + for j := jfirst; j != jlast; j += jinc { + if grow <= smlnum { + goto Solve + } + xj := 1 + cnorm[j] + grow /= xj + } + } + } + +Solve: + if grow*tscal > smlnum { + // Use the Level 2 BLAS solve if the reciprocal of the bound on + // elements of X is not too small. + bi.Dtrsv(uplo, trans, diag, n, a, lda, x, 1) + if tscal != 1 { + bi.Dscal(n, 1/tscal, cnorm, 1) + } + return scale + } + + // Use a Level 1 BLAS solve, scaling intermediate results. + if xmax > bignum { + scale = bignum / xmax + bi.Dscal(n, scale, x, 1) + xmax = bignum + } + if trans == blas.NoTrans { + for j := jfirst; j != jlast; j += jinc { + xj := math.Abs(x[j]) + var tjj, tjjs float64 + if nonUnit { + tjjs = a[j*lda+j] * tscal + } else { + tjjs = tscal + if tscal == 1 { + goto Skip1 + } + } + tjj = math.Abs(tjjs) + if tjj > smlnum { + if tjj < 1 { + if xj > tjj*bignum { + rec := 1 / xj + bi.Dscal(n, rec, x, 1) + scale *= rec + xmax *= rec + } + } + x[j] /= tjjs + xj = math.Abs(x[j]) + } else if tjj > 0 { + if xj > tjj*bignum { + rec := (tjj * bignum) / xj + if cnorm[j] > 1 { + rec /= cnorm[j] + } + bi.Dscal(n, rec, x, 1) + scale *= rec + xmax *= rec + } + x[j] /= tjjs + xj = math.Abs(x[j]) + } else { + for i := 0; i < n; i++ { + x[i] = 0 + } + x[j] = 1 + xj = 1 + scale = 0 + xmax = 0 + } + Skip1: + if xj > 1 { + rec := 1 / xj + if cnorm[j] > (bignum-xmax)*rec { + rec *= 0.5 + bi.Dscal(n, rec, x, 1) + scale *= rec + } + } else if xj*cnorm[j] > bignum-xmax { + bi.Dscal(n, 0.5, x, 1) + scale *= 0.5 + } + if upper { + if j > 0 { + bi.Daxpy(j, -x[j]*tscal, a[j:], lda, x, 1) + i := bi.Idamax(j, x, 1) + xmax = math.Abs(x[i]) + } + } else { + if j < n-1 { + bi.Daxpy(n-j-1, -x[j]*tscal, a[(j+1)*lda+j:], lda, x[j+1:], 1) + i := j + bi.Idamax(n-j-1, x[j+1:], 1) + xmax = math.Abs(x[i]) + } + } + } + } else { + for j := jfirst; j != jlast; j += jinc { + xj := math.Abs(x[j]) + uscal := tscal + rec := 1 / math.Max(xmax, 1) + var tjjs float64 + if cnorm[j] > (bignum-xj)*rec { + rec *= 0.5 + if nonUnit { + tjjs = a[j*lda+j] * tscal + } else { + tjjs = tscal + } + tjj := math.Abs(tjjs) + if tjj > 1 { + rec = math.Min(1, rec*tjj) + uscal /= tjjs + } + if rec < 1 { + bi.Dscal(n, rec, x, 1) + scale *= rec + xmax *= rec + } + } + var sumj float64 + if uscal == 1 { + if upper { + sumj = bi.Ddot(j, a[j:], lda, x, 1) + } else if j < n-1 { + sumj = bi.Ddot(n-j-1, a[(j+1)*lda+j:], lda, x[j+1:], 1) + } + } else { + if upper { + for i := 0; i < j; i++ { + sumj += (a[i*lda+j] * uscal) * 
x[i]
+					}
+				} else if j < n-1 {
+					for i := j + 1; i < n; i++ {
+						sumj += (a[i*lda+j] * uscal) * x[i]
+					}
+				}
+			}
+			if uscal == tscal {
+				x[j] -= sumj
+				xj := math.Abs(x[j])
+				var tjjs float64
+				if nonUnit {
+					tjjs = a[j*lda+j] * tscal
+				} else {
+					tjjs = tscal
+					if tscal == 1 {
+						goto Skip2
+					}
+				}
+				tjj := math.Abs(tjjs)
+				if tjj > smlnum {
+					if tjj < 1 {
+						if xj > tjj*bignum {
+							rec = 1 / xj
+							bi.Dscal(n, rec, x, 1)
+							scale *= rec
+							xmax *= rec
+						}
+					}
+					x[j] /= tjjs
+				} else if tjj > 0 {
+					if xj > tjj*bignum {
+						rec = (tjj * bignum) / xj
+						bi.Dscal(n, rec, x, 1)
+						scale *= rec
+						xmax *= rec
+					}
+					x[j] /= tjjs
+				} else {
+					for i := 0; i < n; i++ {
+						x[i] = 0
+					}
+					x[j] = 1
+					scale = 0
+					xmax = 0
+				}
+			} else {
+				x[j] = x[j]/tjjs - sumj
+			}
+		Skip2:
+			xmax = math.Max(xmax, math.Abs(x[j]))
+		}
+	}
+	scale /= tscal
+	if tscal != 1 {
+		bi.Dscal(n, 1/tscal, cnorm, 1)
+	}
+	return scale
+}
diff --git a/vendor/gonum.org/v1/gonum/lapack/lapack64/lapack64.go b/vendor/gonum.org/v1/gonum/lapack/lapack64/lapack64.go
new file mode 100644
index 0000000000..acb62da4a0
--- /dev/null
+++ b/vendor/gonum.org/v1/gonum/lapack/lapack64/lapack64.go
@@ -0,0 +1,826 @@
+// Copyright ©2015 The Gonum Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package lapack64
+
+import (
+	"gonum.org/v1/gonum/blas"
+	"gonum.org/v1/gonum/blas/blas64"
+	"gonum.org/v1/gonum/lapack"
+	"gonum.org/v1/gonum/lapack/gonum"
+)
+
+var lapack64 lapack.Float64 = gonum.Implementation{}
+
+// Use sets the LAPACK float64 implementation to be used by subsequent LAPACK calls.
+// The default implementation is gonum.Implementation.
+func Use(l lapack.Float64) {
+	lapack64 = l
+}
+
+// Tridiagonal represents a tridiagonal matrix using its three diagonals.
+type Tridiagonal struct {
+	N  int
+	DL []float64
+	D  []float64
+	DU []float64
+}
+
+func max(a, b int) int {
+	if a > b {
+		return a
+	}
+	return b
+}
+
+// Potrf computes the Cholesky factorization of a.
+// The factorization has the form
+//
+//	A = Uᵀ * U  if a.Uplo == blas.Upper, or
+//	A = L * Lᵀ  if a.Uplo == blas.Lower,
+//
+// where U is an upper triangular matrix and L is lower triangular.
+// The triangular matrix is returned in t, and the underlying data between
+// a and t is shared. The returned bool indicates whether a is positive
+// definite and the factorization could be finished.
+func Potrf(a blas64.Symmetric) (t blas64.Triangular, ok bool) {
+	ok = lapack64.Dpotrf(a.Uplo, a.N, a.Data, max(1, a.Stride))
+	t.Uplo = a.Uplo
+	t.N = a.N
+	t.Data = a.Data
+	t.Stride = a.Stride
+	t.Diag = blas.NonUnit
+	return
+}
+
+// Potri computes the inverse of a real symmetric positive definite matrix A
+// using its Cholesky factorization.
+//
+// On entry, t contains the triangular factor U or L from the Cholesky
+// factorization A = Uᵀ*U or A = L*Lᵀ, as computed by Potrf.
+//
+// On return, the upper or lower triangle of the (symmetric) inverse of A is
+// stored in t, overwriting the input factor U or L, and also returned in a. The
+// underlying data between a and t is shared.
+//
+// The returned bool indicates whether the inverse was computed successfully.
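A minimal sketch of the Potrf/Potri pair in use (Potri's definition follows). The 2×2 matrix and the expected output are illustrative values, not part of this change, and the sketch assumes the vendored packages are importable by their canonical paths:

```go
package main

import (
	"fmt"

	"gonum.org/v1/gonum/blas"
	"gonum.org/v1/gonum/blas/blas64"
	"gonum.org/v1/gonum/lapack/lapack64"
)

func main() {
	// A = [4 2; 2 3], stored in the upper triangle.
	a := blas64.Symmetric{
		Uplo:   blas.Upper,
		N:      2,
		Stride: 2,
		Data:   []float64{4, 2, 0, 3},
	}
	// Factor A = Uᵀ*U in place; the returned triangular view shares a.Data.
	u, ok := lapack64.Potrf(a)
	if !ok {
		panic("matrix is not positive definite")
	}
	// Overwrite the factor with the upper triangle of A⁻¹.
	inv, ok := lapack64.Potri(u)
	if !ok {
		panic("inversion failed")
	}
	fmt.Println(inv.Data) // expect [0.375 -0.25 0 0.5]
}
```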
+func Potri(t blas64.Triangular) (a blas64.Symmetric, ok bool) { + ok = lapack64.Dpotri(t.Uplo, t.N, t.Data, max(1, t.Stride)) + a.Uplo = t.Uplo + a.N = t.N + a.Data = t.Data + a.Stride = t.Stride + return +} + +// Potrs solves a system of n linear equations A*X = B where A is an n×n +// symmetric positive definite matrix and B is an n×nrhs matrix, using the +// Cholesky factorization A = Uᵀ*U or A = L*Lᵀ. t contains the corresponding +// triangular factor as returned by Potrf. On entry, B contains the right-hand +// side matrix B, on return it contains the solution matrix X. +func Potrs(t blas64.Triangular, b blas64.General) { + lapack64.Dpotrs(t.Uplo, t.N, b.Cols, t.Data, max(1, t.Stride), b.Data, max(1, b.Stride)) +} + +// Pbcon returns an estimate of the reciprocal of the condition number (in the +// 1-norm) of an n×n symmetric positive definite band matrix using the Cholesky +// factorization +// +// A = Uᵀ*U if uplo == blas.Upper +// A = L*Lᵀ if uplo == blas.Lower +// +// computed by Pbtrf. The estimate is obtained for norm(inv(A)), and the +// reciprocal of the condition number is computed as +// +// rcond = 1 / (anorm * norm(inv(A))). +// +// The length of work must be at least 3*n and the length of iwork must be at +// least n. +func Pbcon(a blas64.SymmetricBand, anorm float64, work []float64, iwork []int) float64 { + return lapack64.Dpbcon(a.Uplo, a.N, a.K, a.Data, a.Stride, anorm, work, iwork) +} + +// Pbtrf computes the Cholesky factorization of an n×n symmetric positive +// definite band matrix +// +// A = Uᵀ * U if a.Uplo == blas.Upper +// A = L * Lᵀ if a.Uplo == blas.Lower +// +// where U and L are upper, respectively lower, triangular band matrices. +// +// The triangular matrix U or L is returned in t, and the underlying data +// between a and t is shared. The returned bool indicates whether A is positive +// definite and the factorization could be finished. +func Pbtrf(a blas64.SymmetricBand) (t blas64.TriangularBand, ok bool) { + ok = lapack64.Dpbtrf(a.Uplo, a.N, a.K, a.Data, max(1, a.Stride)) + t.Uplo = a.Uplo + t.Diag = blas.NonUnit + t.N = a.N + t.K = a.K + t.Data = a.Data + t.Stride = a.Stride + return t, ok +} + +// Pbtrs solves a system of linear equations A*X = B with an n×n symmetric +// positive definite band matrix A using the Cholesky factorization +// +// A = Uᵀ * U if t.Uplo == blas.Upper +// A = L * Lᵀ if t.Uplo == blas.Lower +// +// t contains the corresponding triangular factor as returned by Pbtrf. +// +// On entry, b contains the right hand side matrix B. On return, it is +// overwritten with the solution matrix X. +func Pbtrs(t blas64.TriangularBand, b blas64.General) { + lapack64.Dpbtrs(t.Uplo, t.N, t.K, b.Cols, t.Data, max(1, t.Stride), b.Data, max(1, b.Stride)) +} + +// Pstrf computes the Cholesky factorization with complete pivoting of an n×n +// symmetric positive semidefinite matrix A. +// +// The factorization has the form +// +// Pᵀ * A * P = Uᵀ * U , if a.Uplo = blas.Upper, +// Pᵀ * A * P = L * Lᵀ, if a.Uplo = blas.Lower, +// +// where U is an upper triangular matrix, L is lower triangular, and P is a +// permutation matrix. +// +// tol is a user-defined tolerance. The algorithm terminates if the pivot is +// less than or equal to tol. If tol is negative, then n*eps*max(A[k,k]) will be +// used instead. +// +// The triangular factor U or L from the Cholesky factorization is returned in t +// and the underlying data between a and t is shared. P is stored on return in +// vector piv such that P[piv[k],k] = 1. 
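The band variants compose the same way. A hedged sketch of Pbtrf followed by Pbtrs on a tridiagonal SPD matrix (bandwidth k = 1); the data and the expected solution are illustrative:

```go
package main

import (
	"fmt"

	"gonum.org/v1/gonum/blas"
	"gonum.org/v1/gonum/blas/blas64"
	"gonum.org/v1/gonum/lapack/lapack64"
)

func main() {
	// Tridiagonal SPD matrix with 2 on the diagonal and -1 off it,
	// stored as an upper symmetric band: row i holds A[i,i..i+1].
	a := blas64.SymmetricBand{
		Uplo:   blas.Upper,
		N:      3,
		K:      1,
		Stride: 2,
		Data:   []float64{2, -1, 2, -1, 2, 0},
	}
	t, ok := lapack64.Pbtrf(a)
	if !ok {
		panic("matrix is not positive definite")
	}
	// Right-hand side b = (1, 1, 1)ᵀ, overwritten with the solution x.
	b := blas64.General{Rows: 3, Cols: 1, Stride: 1, Data: []float64{1, 1, 1}}
	lapack64.Pbtrs(t, b)
	fmt.Println(b.Data) // expect [1.5 2 1.5]
}
```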
+//
+// Pstrf returns the computed rank of A and whether the factorization can be
+// used to solve a system. Pstrf does not attempt to check that A is positive
+// semi-definite, so if ok is false, the matrix A is either rank deficient or is
+// not positive semidefinite.
+//
+// The length of piv must be n and the length of work must be at least 2*n,
+// otherwise Pstrf will panic.
+func Pstrf(a blas64.Symmetric, piv []int, tol float64, work []float64) (t blas64.Triangular, rank int, ok bool) {
+	rank, ok = lapack64.Dpstrf(a.Uplo, a.N, a.Data, max(1, a.Stride), piv, tol, work)
+	t.Uplo = a.Uplo
+	t.Diag = blas.NonUnit
+	t.N = a.N
+	t.Data = a.Data
+	t.Stride = a.Stride
+	return t, rank, ok
+}
+
+// Gecon estimates the reciprocal of the condition number of the n×n matrix A
+// given the LU decomposition of the matrix. The condition number computed may
+// be based on the 1-norm or the ∞-norm.
+//
+// a contains the result of the LU decomposition of A as computed by Getrf.
+//
+// anorm is the corresponding 1-norm or ∞-norm of the original matrix A.
+//
+// work is a temporary data slice of length at least 4*n and Gecon will panic otherwise.
+//
+// iwork is a temporary data slice of length at least n and Gecon will panic otherwise.
+func Gecon(norm lapack.MatrixNorm, a blas64.General, anorm float64, work []float64, iwork []int) float64 {
+	return lapack64.Dgecon(norm, a.Cols, a.Data, max(1, a.Stride), anorm, work, iwork)
+}
+
+// Gels finds a minimum-norm solution based on the matrices A and B using the
+// QR or LQ factorization. Gels returns false if the matrix
+// A is singular, and true if this solution was successfully found.
+//
+// The minimization problem solved depends on the input parameters.
+//
+//  1. If m >= n and trans == blas.NoTrans, Gels finds X such that || A*X - B||_2
+//     is minimized.
+//  2. If m < n and trans == blas.NoTrans, Gels finds the minimum norm solution of
+//     A * X = B.
+//  3. If m >= n and trans == blas.Trans, Gels finds the minimum norm solution of
+//     Aᵀ * X = B.
+//  4. If m < n and trans == blas.Trans, Gels finds X such that || Aᵀ*X - B||_2
+//     is minimized.
+//
+// Note that the least-squares solutions (cases 1 and 4) perform the minimization
+// per column of B. This is not the same as finding the minimum-norm matrix.
+//
+// The matrix A is a general matrix of size m×n and is modified during this call.
+// The input matrix B is of size max(m,n)×nrhs, and serves two purposes. On entry,
+// the elements of b specify the input matrix B. B has size m×nrhs if
+// trans == blas.NoTrans, and n×nrhs if trans == blas.Trans. On exit, the
+// leading submatrix of b contains the solution vectors X. If trans == blas.NoTrans,
+// this submatrix is of size n×nrhs, and of size m×nrhs otherwise.
+//
+// Work is temporary storage, and lwork specifies the usable memory length.
+// At minimum, lwork >= max(m,n) + max(m,n,nrhs), and this function will panic
+// otherwise. A longer work will enable blocked algorithms to be called.
+// In the special case that lwork == -1, work[0] will be set to the optimal working
+// length.
+func Gels(trans blas.Transpose, a blas64.General, b blas64.General, work []float64, lwork int) bool {
+	return lapack64.Dgels(trans, a.Rows, a.Cols, b.Cols, a.Data, max(1, a.Stride), b.Data, max(1, b.Stride), work, lwork)
+}
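The lwork == -1 query convention is easiest to see in use. A minimal sketch of Gels on a small overdetermined system; the points and the expected coefficients are illustrative:

```go
package main

import (
	"fmt"

	"gonum.org/v1/gonum/blas"
	"gonum.org/v1/gonum/blas/blas64"
	"gonum.org/v1/gonum/lapack/lapack64"
)

func main() {
	// Fit y = c0 + c1*t to the points (1,1), (2,2), (3,2).
	a := blas64.General{Rows: 3, Cols: 2, Stride: 2, Data: []float64{
		1, 1,
		1, 2,
		1, 3,
	}}
	// b has max(m,n)×nrhs backing; the leading n×nrhs block holds X on exit.
	b := blas64.General{Rows: 3, Cols: 1, Stride: 1, Data: []float64{1, 2, 2}}

	// Workspace query: with lwork == -1 the optimal length lands in work[0].
	work := make([]float64, 1)
	lapack64.Gels(blas.NoTrans, a, b, work, -1)
	work = make([]float64, int(work[0]))

	ok := lapack64.Gels(blas.NoTrans, a, b, work, len(work))
	fmt.Println(ok, b.Data[:2]) // expect coefficients ≈ [0.667 0.5]
}
```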
+
+// Geqrf computes the QR factorization of the m×n matrix A using a blocked
+// algorithm. A is modified to contain the information to construct Q and R.
+// The upper triangle of a contains the matrix R. The lower triangular elements
+// (not including the diagonal) contain the elementary reflectors. tau is modified
+// to contain the reflector scales. tau must have length at least min(m,n), and
+// this function will panic otherwise.
+//
+// The ith elementary reflector can be explicitly constructed by first extracting
+// the
+//
+//	v[j] = 0           j < i
+//	v[j] = 1           j == i
+//	v[j] = a[j*lda+i]  j > i
+//
+// and computing H_i = I - tau[i] * v * vᵀ.
+//
+// The orthonormal matrix Q can be constructed from a product of these elementary
+// reflectors, Q = H_0 * H_1 * ... * H_{k-1}, where k = min(m,n).
+//
+// Work is temporary storage, and lwork specifies the usable memory length.
+// At minimum, lwork >= m and this function will panic otherwise.
+// Geqrf is a blocked QR factorization, but the block size is limited
+// by the temporary space available. If lwork == -1, instead of performing Geqrf,
+// the optimal work length will be stored into work[0].
+func Geqrf(a blas64.General, tau, work []float64, lwork int) {
+	lapack64.Dgeqrf(a.Rows, a.Cols, a.Data, max(1, a.Stride), tau, work, lwork)
+}
+
+// Gelqf computes the LQ factorization of the m×n matrix A using a blocked
+// algorithm. A is modified to contain the information to construct L and Q. The
+// lower triangle of a contains the matrix L. The elements above the diagonal
+// and the slice tau represent the matrix Q. tau is modified to contain the
+// reflector scales. tau must have length at least min(m,n), and this function
+// will panic otherwise.
+//
+// See Geqrf for a description of the elementary reflectors and orthonormal
+// matrix Q. Q is constructed as a product of these elementary reflectors,
+// Q = H_{k-1} * ... * H_1 * H_0.
+//
+// Work is temporary storage, and lwork specifies the usable memory length.
+// At minimum, lwork >= m and this function will panic otherwise.
+// Gelqf is a blocked LQ factorization, but the block size is limited
+// by the temporary space available. If lwork == -1, instead of performing Gelqf,
+// the optimal work length will be stored into work[0].
+func Gelqf(a blas64.General, tau, work []float64, lwork int) {
+	lapack64.Dgelqf(a.Rows, a.Cols, a.Data, max(1, a.Stride), tau, work, lwork)
+}
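The same query idiom drives the blocked factorizations. A short illustrative sketch of Geqrf on an arbitrary 4×3 matrix:

```go
package main

import (
	"fmt"

	"gonum.org/v1/gonum/blas/blas64"
	"gonum.org/v1/gonum/lapack/lapack64"
)

func main() {
	a := blas64.General{Rows: 4, Cols: 3, Stride: 3, Data: []float64{
		1, 2, 3,
		4, 5, 6,
		7, 8, 10,
		1, 0, 1,
	}}
	tau := make([]float64, 3) // min(m,n) reflector scales

	// Query the optimal workspace, then factor A = Q*R in place.
	work := make([]float64, 1)
	lapack64.Geqrf(a, tau, work, -1)
	work = make([]float64, int(work[0]))
	lapack64.Geqrf(a, tau, work, len(work))

	// R occupies the upper triangle of a; the reflectors sit below it.
	fmt.Println(a.Data[:3], tau)
}
```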
+
+// Gesvd computes the singular value decomposition of the input matrix A.
+//
+// The singular value decomposition is
+//
+//	A = U * Sigma * Vᵀ
+//
+// where Sigma is an m×n diagonal matrix containing the singular values of A,
+// U is an m×m orthogonal matrix and V is an n×n orthogonal matrix. The first
+// min(m,n) columns of U and V are the left and right singular vectors of A
+// respectively.
+//
+// jobU and jobVT are options for computing the singular vectors. The behavior
+// is as follows
+//
+//	jobU == lapack.SVDAll       All m columns of U are returned in u
+//	jobU == lapack.SVDStore     The first min(m,n) columns are returned in u
+//	jobU == lapack.SVDOverwrite The first min(m,n) columns of U are written into a
+//	jobU == lapack.SVDNone      The columns of U are not computed.
+//
+// The behavior is the same for jobVT and the rows of Vᵀ. At most one of jobU
+// and jobVT can equal lapack.SVDOverwrite, and Gesvd will panic otherwise.
+//
+// On entry, a contains the data for the m×n matrix A. During the call to Gesvd
+// the data is overwritten. On exit, A contains the appropriate singular vectors
+// if either job is lapack.SVDOverwrite.
+//
+// s is a slice of length at least min(m,n) and on exit contains the singular
+// values in decreasing order.
+//
+// u contains the left singular vectors on exit, stored columnwise. If
+// jobU == lapack.SVDAll, u is of size m×m. If jobU == lapack.SVDStore u is
+// of size m×min(m,n). If jobU == lapack.SVDOverwrite or lapack.SVDNone, u is
+// not used.
+//
+// vt contains the right singular vectors on exit, stored rowwise. If
+// jobVT == lapack.SVDAll, vt is of size n×n. If jobVT == lapack.SVDStore vt is
+// of size min(m,n)×n. If jobVT == lapack.SVDOverwrite or lapack.SVDNone, vt is
+// not used.
+//
+// work is a slice for storing temporary memory, and lwork is the usable size of
+// the slice. lwork must be at least max(5*min(m,n), 3*min(m,n)+max(m,n)).
+// If lwork == -1, instead of performing Gesvd, the optimal work length will be
+// stored into work[0]. Gesvd will panic if the working memory has insufficient
+// storage.
+//
+// Gesvd returns whether the decomposition successfully completed.
+func Gesvd(jobU, jobVT lapack.SVDJob, a, u, vt blas64.General, s, work []float64, lwork int) (ok bool) {
+	return lapack64.Dgesvd(jobU, jobVT, a.Rows, a.Cols, a.Data, max(1, a.Stride), s, u.Data, max(1, u.Stride), vt.Data, max(1, vt.Stride), work, lwork)
+}
+
+// Getrf computes the LU decomposition of the m×n matrix A.
+// The LU decomposition is a factorization of A into
+//
+//	A = P * L * U
+//
+// where P is a permutation matrix, L is a unit lower triangular matrix, and
+// U is a (usually) non-unit upper triangular matrix. On exit, L and U are stored
+// in place into a.
+//
+// ipiv is a permutation vector. It indicates that row i of the matrix was
+// changed with ipiv[i]. ipiv must have length at least min(m,n), and will panic
+// otherwise. ipiv is zero-indexed.
+//
+// Getrf is the blocked version of the algorithm.
+//
+// Getrf returns whether the matrix A is nonsingular. The LU decomposition will
+// be computed regardless of the singularity of A, but division by zero
+// will occur if false is returned and the result is used to solve a
+// system of equations.
+func Getrf(a blas64.General, ipiv []int) bool {
+	return lapack64.Dgetrf(a.Rows, a.Cols, a.Data, max(1, a.Stride), ipiv)
+}
+
+// Getri computes the inverse of the matrix A using the LU factorization computed
+// by Getrf. On entry, a contains the PLU decomposition of A as computed by
+// Getrf and on exit contains the inverse of the original matrix.
+//
+// Getri will not perform the inversion if the matrix is singular, and returns
+// a boolean indicating whether the inversion was successful.
+//
+// Work is temporary storage, and lwork specifies the usable memory length.
+// At minimum, lwork >= n and this function will panic otherwise.
+// Getri is a blocked inversion, but the block size is limited
+// by the temporary space available. If lwork == -1, instead of performing Getri,
+// the optimal work length will be stored into work[0].
+func Getri(a blas64.General, ipiv []int, work []float64, lwork int) (ok bool) {
+	return lapack64.Dgetri(a.Cols, a.Data, max(1, a.Stride), ipiv, work, lwork)
+}
+
+// Getrs solves a system of equations using an LU factorization.
+// The system of equations solved is
+//
+//	A * X = B   if trans == blas.NoTrans
+//	Aᵀ * X = B  if trans == blas.Trans
+//
+// A is a general n×n matrix with stride lda. B is a general matrix of size n×nrhs.
+//
+// On entry b contains the elements of the matrix B. On exit, b contains the
+// elements of X, the solution to the system of equations.
+//
+// a and ipiv contain the LU factorization of A and the permutation indices as
+// computed by Getrf.
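Getrf and Getrs (whose definition follows) combine into the standard LU solve path. A minimal sketch with illustrative values:

```go
package main

import (
	"fmt"

	"gonum.org/v1/gonum/blas"
	"gonum.org/v1/gonum/blas/blas64"
	"gonum.org/v1/gonum/lapack/lapack64"
)

func main() {
	a := blas64.General{Rows: 2, Cols: 2, Stride: 2, Data: []float64{
		2, 1,
		1, 3,
	}}
	ipiv := make([]int, 2)
	if !lapack64.Getrf(a, ipiv) {
		panic("matrix is singular")
	}
	// Solve A*x = (3, 5)ᵀ; b is overwritten with the solution.
	b := blas64.General{Rows: 2, Cols: 1, Stride: 1, Data: []float64{3, 5}}
	lapack64.Getrs(blas.NoTrans, a, b, ipiv)
	fmt.Println(b.Data) // expect [0.8 1.4]
}
```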
ipiv is zero-indexed. +func Getrs(trans blas.Transpose, a blas64.General, b blas64.General, ipiv []int) { + lapack64.Dgetrs(trans, a.Cols, b.Cols, a.Data, max(1, a.Stride), ipiv, b.Data, max(1, b.Stride)) +} + +// Ggsvd3 computes the generalized singular value decomposition (GSVD) +// of an m×n matrix A and p×n matrix B: +// +// Uᵀ*A*Q = D1*[ 0 R ] +// +// Vᵀ*B*Q = D2*[ 0 R ] +// +// where U, V and Q are orthogonal matrices. +// +// Ggsvd3 returns k and l, the dimensions of the sub-blocks. k+l +// is the effective numerical rank of the (m+p)×n matrix [ Aᵀ Bᵀ ]ᵀ. +// R is a (k+l)×(k+l) nonsingular upper triangular matrix, D1 and +// D2 are m×(k+l) and p×(k+l) diagonal matrices and of the following +// structures, respectively: +// +// If m-k-l >= 0, +// +// k l +// D1 = k [ I 0 ] +// l [ 0 C ] +// m-k-l [ 0 0 ] +// +// k l +// D2 = l [ 0 S ] +// p-l [ 0 0 ] +// +// n-k-l k l +// [ 0 R ] = k [ 0 R11 R12 ] k +// l [ 0 0 R22 ] l +// +// where +// +// C = diag( alpha_k, ... , alpha_{k+l} ), +// S = diag( beta_k, ... , beta_{k+l} ), +// C^2 + S^2 = I. +// +// R is stored in +// +// A[0:k+l, n-k-l:n] +// +// on exit. +// +// If m-k-l < 0, +// +// k m-k k+l-m +// D1 = k [ I 0 0 ] +// m-k [ 0 C 0 ] +// +// k m-k k+l-m +// D2 = m-k [ 0 S 0 ] +// k+l-m [ 0 0 I ] +// p-l [ 0 0 0 ] +// +// n-k-l k m-k k+l-m +// [ 0 R ] = k [ 0 R11 R12 R13 ] +// m-k [ 0 0 R22 R23 ] +// k+l-m [ 0 0 0 R33 ] +// +// where +// +// C = diag( alpha_k, ... , alpha_m ), +// S = diag( beta_k, ... , beta_m ), +// C^2 + S^2 = I. +// +// R = [ R11 R12 R13 ] is stored in A[1:m, n-k-l+1:n] +// [ 0 R22 R23 ] +// +// and R33 is stored in +// +// B[m-k:l, n+m-k-l:n] on exit. +// +// Ggsvd3 computes C, S, R, and optionally the orthogonal transformation +// matrices U, V and Q. +// +// jobU, jobV and jobQ are options for computing the orthogonal matrices. The behavior +// is as follows +// +// jobU == lapack.GSVDU Compute orthogonal matrix U +// jobU == lapack.GSVDNone Do not compute orthogonal matrix. +// +// The behavior is the same for jobV and jobQ with the exception that instead of +// lapack.GSVDU these accept lapack.GSVDV and lapack.GSVDQ respectively. +// The matrices U, V and Q must be m×m, p×p and n×n respectively unless the +// relevant job parameter is lapack.GSVDNone. +// +// alpha and beta must have length n or Ggsvd3 will panic. On exit, alpha and +// beta contain the generalized singular value pairs of A and B +// +// alpha[0:k] = 1, +// beta[0:k] = 0, +// +// if m-k-l >= 0, +// +// alpha[k:k+l] = diag(C), +// beta[k:k+l] = diag(S), +// +// if m-k-l < 0, +// +// alpha[k:m]= C, alpha[m:k+l]= 0 +// beta[k:m] = S, beta[m:k+l] = 1. +// +// if k+l < n, +// +// alpha[k+l:n] = 0 and +// beta[k+l:n] = 0. +// +// On exit, iwork contains the permutation required to sort alpha descending. +// +// iwork must have length n, work must have length at least max(1, lwork), and +// lwork must be -1 or greater than n, otherwise Ggsvd3 will panic. If +// lwork is -1, work[0] holds the optimal lwork on return, but Ggsvd3 does +// not perform the GSVD. 
+func Ggsvd3(jobU, jobV, jobQ lapack.GSVDJob, a, b blas64.General, alpha, beta []float64, u, v, q blas64.General, work []float64, lwork int, iwork []int) (k, l int, ok bool) { + return lapack64.Dggsvd3(jobU, jobV, jobQ, a.Rows, a.Cols, b.Rows, a.Data, max(1, a.Stride), b.Data, max(1, b.Stride), alpha, beta, u.Data, max(1, u.Stride), v.Data, max(1, v.Stride), q.Data, max(1, q.Stride), work, lwork, iwork) +} + +// Gtsv solves one of the equations +// +// A * X = B if trans == blas.NoTrans +// Aᵀ * X = B if trans == blas.Trans or blas.ConjTrans +// +// where A is an n×n tridiagonal matrix. It uses Gaussian elimination with +// partial pivoting. +// +// On entry, a contains the matrix A, on return it will be overwritten. +// +// On entry, b contains the n×nrhs right-hand side matrix B. On return, it will +// be overwritten. If ok is true, it will be overwritten by the solution matrix X. +// +// Gtsv returns whether the solution X has been successfully computed. +// +// Dgtsv is not part of the lapack.Float64 interface and so calls to Gtsv are +// always executed by the Gonum implementation. +func Gtsv(trans blas.Transpose, a Tridiagonal, b blas64.General) (ok bool) { + if trans != blas.NoTrans { + a.DL, a.DU = a.DU, a.DL + } + return gonum.Implementation{}.Dgtsv(a.N, b.Cols, a.DL, a.D, a.DU, b.Data, max(1, b.Stride)) +} + +// Lagtm performs one of the matrix-matrix operations +// +// C = alpha * A * B + beta * C if trans == blas.NoTrans +// C = alpha * Aᵀ * B + beta * C if trans == blas.Trans or blas.ConjTrans +// +// where A is an m×m tridiagonal matrix represented by its diagonals dl, d, du, +// B and C are m×n dense matrices, and alpha and beta are scalars. +// +// Dlagtm is not part of the lapack.Float64 interface and so calls to Lagtm are +// always executed by the Gonum implementation. +func Lagtm(trans blas.Transpose, alpha float64, a Tridiagonal, b blas64.General, beta float64, c blas64.General) { + gonum.Implementation{}.Dlagtm(trans, c.Rows, c.Cols, alpha, a.DL, a.D, a.DU, b.Data, max(1, b.Stride), beta, c.Data, max(1, c.Stride)) +} + +// Lange computes the matrix norm of the general m×n matrix A. The input norm +// specifies the norm computed. +// +// lapack.MaxAbs: the maximum absolute value of an element. +// lapack.MaxColumnSum: the maximum column sum of the absolute values of the entries. +// lapack.MaxRowSum: the maximum row sum of the absolute values of the entries. +// lapack.Frobenius: the square root of the sum of the squares of the entries. +// +// If norm == lapack.MaxColumnSum, work must be of length n, and this function will panic otherwise. +// There are no restrictions on work for the other matrix norms. +func Lange(norm lapack.MatrixNorm, a blas64.General, work []float64) float64 { + return lapack64.Dlange(norm, a.Rows, a.Cols, a.Data, max(1, a.Stride), work) +} + +// Langb returns the given norm of a general m×n band matrix with kl sub-diagonals and +// ku super-diagonals. +// +// Dlangb is not part of the lapack.Float64 interface and so calls to Langb are always +// executed by the Gonum implementation. +func Langb(norm lapack.MatrixNorm, a blas64.Band) float64 { + return gonum.Implementation{}.Dlangb(norm, a.Rows, a.Cols, a.KL, a.KU, a.Data, max(1, a.Stride)) +} + +// Langt computes the specified norm of an n×n tridiagonal matrix. +// +// Dlangt is not part of the lapack.Float64 interface and so calls to Langt are +// always executed by the Gonum implementation. 
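For tridiagonal systems the Tridiagonal representation above avoids band storage entirely. An illustrative sketch of Gtsv (note that both a and b are overwritten):

```go
package main

import (
	"fmt"

	"gonum.org/v1/gonum/blas"
	"gonum.org/v1/gonum/blas/blas64"
	"gonum.org/v1/gonum/lapack/lapack64"
)

func main() {
	// 2 on the diagonal, -1 on the sub- and superdiagonals.
	a := lapack64.Tridiagonal{
		N:  3,
		DL: []float64{-1, -1},
		D:  []float64{2, 2, 2},
		DU: []float64{-1, -1},
	}
	b := blas64.General{Rows: 3, Cols: 1, Stride: 1, Data: []float64{1, 1, 1}}
	if ok := lapack64.Gtsv(blas.NoTrans, a, b); !ok {
		panic("matrix is singular")
	}
	fmt.Println(b.Data) // expect [1.5 2 1.5]
}
```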
+func Langt(norm lapack.MatrixNorm, a Tridiagonal) float64 {
+	return gonum.Implementation{}.Dlangt(norm, a.N, a.DL, a.D, a.DU)
+}
+
+// Lansb computes the specified norm of an n×n symmetric band matrix. If
+// norm == lapack.MaxColumnSum or norm == lapack.MaxRowSum, work must have length
+// at least n and this function will panic otherwise.
+// There are no restrictions on work for the other matrix norms.
+//
+// Dlansb is not part of the lapack.Float64 interface and so calls to Lansb are always
+// executed by the Gonum implementation.
+func Lansb(norm lapack.MatrixNorm, a blas64.SymmetricBand, work []float64) float64 {
+	return gonum.Implementation{}.Dlansb(norm, a.Uplo, a.N, a.K, a.Data, max(1, a.Stride), work)
+}
+
+// Lansy computes the specified norm of an n×n symmetric matrix. If
+// norm == lapack.MaxColumnSum or norm == lapack.MaxRowSum, work must have length
+// at least n and this function will panic otherwise.
+// There are no restrictions on work for the other matrix norms.
+func Lansy(norm lapack.MatrixNorm, a blas64.Symmetric, work []float64) float64 {
+	return lapack64.Dlansy(norm, a.Uplo, a.N, a.Data, max(1, a.Stride), work)
+}
+
+// Lantr computes the specified norm of an m×n trapezoidal matrix A. If
+// norm == lapack.MaxColumnSum work must have length at least n and this function
+// will panic otherwise. There are no restrictions on work for the other matrix norms.
+func Lantr(norm lapack.MatrixNorm, a blas64.Triangular, work []float64) float64 {
+	return lapack64.Dlantr(norm, a.Uplo, a.Diag, a.N, a.N, a.Data, max(1, a.Stride), work)
+}
+
+// Lantb computes the specified norm of an n×n triangular band matrix A. If
+// norm == lapack.MaxColumnSum work must have length at least n and this function
+// will panic otherwise. There are no restrictions on work for the other matrix
+// norms.
+func Lantb(norm lapack.MatrixNorm, a blas64.TriangularBand, work []float64) float64 {
+	return gonum.Implementation{}.Dlantb(norm, a.Uplo, a.Diag, a.N, a.K, a.Data, max(1, a.Stride), work)
+}
+
+// Lapmr rearranges the rows of the m×n matrix X as specified by the permutation
+// k[0],k[1],...,k[m-1] of the integers 0,...,m-1.
+//
+// If forward is true, a forward permutation is applied:
+//
+//	X[k[i],0:n] is moved to X[i,0:n] for i=0,1,...,m-1.
+//
+// If forward is false, a backward permutation is applied:
+//
+//	X[i,0:n] is moved to X[k[i],0:n] for i=0,1,...,m-1.
+//
+// k must have length m, otherwise Lapmr will panic.
+func Lapmr(forward bool, x blas64.General, k []int) {
+	lapack64.Dlapmr(forward, x.Rows, x.Cols, x.Data, max(1, x.Stride), k)
+}
+
+// Lapmt rearranges the columns of the m×n matrix X as specified by the
+// permutation k_0, k_1, ..., k_{n-1} of the integers 0, ..., n-1.
+//
+// If forward is true a forward permutation is performed:
+//
+//	X[0:m, k[j]] is moved to X[0:m, j] for j = 0, 1, ..., n-1.
+//
+// otherwise a backward permutation is performed:
+//
+//	X[0:m, j] is moved to X[0:m, k[j]] for j = 0, 1, ..., n-1.
+//
+// k must have length n, otherwise Lapmt will panic. k is zero-indexed.
+func Lapmt(forward bool, x blas64.General, k []int) {
+	lapack64.Dlapmt(forward, x.Rows, x.Cols, x.Data, max(1, x.Stride), k)
+}
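The norm selectors are plain constants from the lapack package. A small illustrative sketch of Lange on a 2×2 matrix:

```go
package main

import (
	"fmt"

	"gonum.org/v1/gonum/blas/blas64"
	"gonum.org/v1/gonum/lapack"
	"gonum.org/v1/gonum/lapack/lapack64"
)

func main() {
	a := blas64.General{Rows: 2, Cols: 2, Stride: 2, Data: []float64{
		1, -2,
		3, 4,
	}}
	work := make([]float64, a.Cols) // required only for MaxColumnSum
	fmt.Println(lapack64.Lange(lapack.MaxAbs, a, nil))        // 4
	fmt.Println(lapack64.Lange(lapack.MaxColumnSum, a, work)) // |−2|+|4| = 6
	fmt.Println(lapack64.Lange(lapack.MaxRowSum, a, nil))     // |3|+|4| = 7
	fmt.Println(lapack64.Lange(lapack.Frobenius, a, nil))     // √30 ≈ 5.477
}
```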
+// Ormlq multiplies the matrix C by the orthogonal matrix Q defined by
+// A and tau. A and tau are as returned from Gelqf.
+//
+//	C = Q * C   if side == blas.Left and trans == blas.NoTrans
+//	C = Qᵀ * C  if side == blas.Left and trans == blas.Trans
+//	C = C * Q   if side == blas.Right and trans == blas.NoTrans
+//	C = C * Qᵀ  if side == blas.Right and trans == blas.Trans
+//
+// If side == blas.Left, A is a matrix of size k×m, and if side == blas.Right
+// A is of size k×n. This uses a blocked algorithm.
+//
+// Work is temporary storage, and lwork specifies the usable memory length.
+// At minimum, lwork >= n if side == blas.Left and lwork >= m if side == blas.Right,
+// and this function will panic otherwise.
+// Ormlq uses a block algorithm, but the block size is limited
+// by the temporary space available. If lwork == -1, instead of performing Ormlq,
+// the optimal work length will be stored into work[0].
+//
+// Tau contains the Householder scales and must have length at least k, and
+// this function will panic otherwise.
+func Ormlq(side blas.Side, trans blas.Transpose, a blas64.General, tau []float64, c blas64.General, work []float64, lwork int) {
+	lapack64.Dormlq(side, trans, c.Rows, c.Cols, a.Rows, a.Data, max(1, a.Stride), tau, c.Data, max(1, c.Stride), work, lwork)
+}
+
+// Ormqr multiplies an m×n matrix C by an orthogonal matrix Q as
+//
+//	C = Q * C   if side == blas.Left and trans == blas.NoTrans,
+//	C = Qᵀ * C  if side == blas.Left and trans == blas.Trans,
+//	C = C * Q   if side == blas.Right and trans == blas.NoTrans,
+//	C = C * Qᵀ  if side == blas.Right and trans == blas.Trans,
+//
+// where Q is defined as the product of k elementary reflectors
+//
+//	Q = H_0 * H_1 * ... * H_{k-1}.
+//
+// If side == blas.Left, A is an m×k matrix and 0 <= k <= m.
+// If side == blas.Right, A is an n×k matrix and 0 <= k <= n.
+// The ith column of A contains the vector which defines the elementary
+// reflector H_i and tau[i] contains its scalar factor. tau must have length k
+// and Ormqr will panic otherwise. Geqrf returns A and tau in the required
+// form.
+//
+// work must have length at least max(1,lwork), and lwork must be at least n if
+// side == blas.Left and at least m if side == blas.Right, otherwise Ormqr will
+// panic.
+//
+// work is temporary storage, and lwork specifies the usable memory length. At
+// minimum, lwork >= n if side == blas.Left and lwork >= m if side ==
+// blas.Right, and this function will panic otherwise. Larger values of lwork
+// will generally give better performance. On return, work[0] will contain the
+// optimal value of lwork.
+//
+// If lwork is -1, instead of performing Ormqr, the optimal workspace size will
+// be stored into work[0].
+func Ormqr(side blas.Side, trans blas.Transpose, a blas64.General, tau []float64, c blas64.General, work []float64, lwork int) {
+	lapack64.Dormqr(side, trans, c.Rows, c.Cols, a.Cols, a.Data, max(1, a.Stride), tau, c.Data, max(1, c.Stride), work, lwork)
+}
+
+// Pocon estimates the reciprocal of the condition number of a positive-definite
+// matrix A given the Cholesky decomposition of A. The condition number computed
+// is based on the 1-norm and the ∞-norm.
+//
+// anorm is the 1-norm and the ∞-norm of the original matrix A.
+//
+// work is a temporary data slice of length at least 3*n and Pocon will panic otherwise.
+//
+// iwork is a temporary data slice of length at least n and Pocon will panic otherwise.
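Ormqr is typically paired with Geqrf to apply Qᵀ without ever forming Q, for example as the first half of a manual least-squares solve. A hedged, illustrative sketch:

```go
package main

import (
	"fmt"

	"gonum.org/v1/gonum/blas"
	"gonum.org/v1/gonum/blas/blas64"
	"gonum.org/v1/gonum/lapack/lapack64"
)

func main() {
	// Factor A = Q*R.
	a := blas64.General{Rows: 3, Cols: 2, Stride: 2, Data: []float64{
		1, 1,
		1, 2,
		1, 3,
	}}
	tau := make([]float64, 2)
	work := make([]float64, 1)
	lapack64.Geqrf(a, tau, work, -1)
	work = make([]float64, int(work[0]))
	lapack64.Geqrf(a, tau, work, len(work))

	// Overwrite b with Qᵀ*b using the reflectors stored in a and tau.
	b := blas64.General{Rows: 3, Cols: 1, Stride: 1, Data: []float64{1, 2, 2}}
	work = make([]float64, 1)
	lapack64.Ormqr(blas.Left, blas.Trans, a, tau, b, work, -1)
	work = make([]float64, int(work[0]))
	lapack64.Ormqr(blas.Left, blas.Trans, a, tau, b, work, len(work))
	fmt.Println(b.Data) // the leading 2 entries feed a triangular solve with R
}
```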
+func Pocon(a blas64.Symmetric, anorm float64, work []float64, iwork []int) float64 { + return lapack64.Dpocon(a.Uplo, a.N, a.Data, max(1, a.Stride), anorm, work, iwork) +} + +// Syev computes all eigenvalues and, optionally, the eigenvectors of a real +// symmetric matrix A. +// +// w contains the eigenvalues in ascending order upon return. w must have length +// at least n, and Syev will panic otherwise. +// +// On entry, a contains the elements of the symmetric matrix A in the triangular +// portion specified by uplo. If jobz == lapack.EVCompute, a contains the +// orthonormal eigenvectors of A on exit, otherwise jobz must be lapack.EVNone +// and on exit the specified triangular region is overwritten. +// +// Work is temporary storage, and lwork specifies the usable memory length. At minimum, +// lwork >= 3*n-1, and Syev will panic otherwise. The amount of blocking is +// limited by the usable length. If lwork == -1, instead of computing Syev the +// optimal work length is stored into work[0]. +func Syev(jobz lapack.EVJob, a blas64.Symmetric, w, work []float64, lwork int) (ok bool) { + return lapack64.Dsyev(jobz, a.Uplo, a.N, a.Data, max(1, a.Stride), w, work, lwork) +} + +// Tbtrs solves a triangular system of the form +// +// A * X = B if trans == blas.NoTrans +// Aᵀ * X = B if trans == blas.Trans or blas.ConjTrans +// +// where A is an n×n triangular band matrix, and B is an n×nrhs matrix. +// +// Tbtrs returns whether A is non-singular. If A is singular, no solutions X +// are computed. +func Tbtrs(trans blas.Transpose, a blas64.TriangularBand, b blas64.General) (ok bool) { + return lapack64.Dtbtrs(a.Uplo, trans, a.Diag, a.N, a.K, b.Cols, a.Data, max(1, a.Stride), b.Data, max(1, b.Stride)) +} + +// Trcon estimates the reciprocal of the condition number of a triangular matrix A. +// The condition number computed may be based on the 1-norm or the ∞-norm. +// +// work is a temporary data slice of length at least 3*n and Trcon will panic otherwise. +// +// iwork is a temporary data slice of length at least n and Trcon will panic otherwise. +func Trcon(norm lapack.MatrixNorm, a blas64.Triangular, work []float64, iwork []int) float64 { + return lapack64.Dtrcon(norm, a.Uplo, a.Diag, a.N, a.Data, max(1, a.Stride), work, iwork) +} + +// Trtri computes the inverse of a triangular matrix, storing the result in place +// into a. +// +// Trtri will not perform the inversion if the matrix is singular, and returns +// a boolean indicating whether the inversion was successful. +func Trtri(a blas64.Triangular) (ok bool) { + return lapack64.Dtrtri(a.Uplo, a.Diag, a.N, a.Data, max(1, a.Stride)) +} + +// Trtrs solves a triangular system of the form A * X = B or Aᵀ * X = B. Trtrs +// returns whether the solve completed successfully. If A is singular, no solve is performed. +func Trtrs(trans blas.Transpose, a blas64.Triangular, b blas64.General) (ok bool) { + return lapack64.Dtrtrs(a.Uplo, trans, a.Diag, a.N, b.Cols, a.Data, max(1, a.Stride), b.Data, max(1, b.Stride)) +} + +// Geev computes the eigenvalues and, optionally, the left and/or right +// eigenvectors for an n×n real nonsymmetric matrix A. +// +// The right eigenvector v_j of A corresponding to an eigenvalue λ_j +// is defined by +// +// A v_j = λ_j v_j, +// +// and the left eigenvector u_j corresponding to an eigenvalue λ_j is defined by +// +// u_jᴴ A = λ_j u_jᴴ, +// +// where u_jᴴ is the conjugate transpose of u_j. 
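A minimal illustrative sketch of Syev, again using the workspace query; the eigenvalues of the sample matrix are 1 and 3:

```go
package main

import (
	"fmt"

	"gonum.org/v1/gonum/blas"
	"gonum.org/v1/gonum/blas/blas64"
	"gonum.org/v1/gonum/lapack"
	"gonum.org/v1/gonum/lapack/lapack64"
)

func main() {
	// A = [2 1; 1 2], upper triangle stored.
	a := blas64.Symmetric{
		Uplo:   blas.Upper,
		N:      2,
		Stride: 2,
		Data:   []float64{2, 1, 0, 2},
	}
	w := make([]float64, 2)
	work := make([]float64, 1)
	lapack64.Syev(lapack.EVCompute, a, w, work, -1)
	work = make([]float64, int(work[0]))
	ok := lapack64.Syev(lapack.EVCompute, a, w, work, len(work))
	// On success the eigenvectors overwrite a.Data; eigenvalues ascend in w.
	fmt.Println(ok, w) // expect true [1 3]
}
```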
+//
+// On return, A will be overwritten and the left and right eigenvectors will be
+// stored, respectively, in the columns of the n×n matrices VL and VR in the
+// same order as their eigenvalues. If the j-th eigenvalue is real, then
+//
+//	u_j = VL[:,j],
+//	v_j = VR[:,j],
+//
+// and if it is not real, then j and j+1 form a complex conjugate pair and the
+// eigenvectors can be recovered as
+//
+//	u_j     = VL[:,j] + i*VL[:,j+1],
+//	u_{j+1} = VL[:,j] - i*VL[:,j+1],
+//	v_j     = VR[:,j] + i*VR[:,j+1],
+//	v_{j+1} = VR[:,j] - i*VR[:,j+1],
+//
+// where i is the imaginary unit. The computed eigenvectors are normalized to
+// have Euclidean norm equal to 1 and largest component real.
+//
+// Left eigenvectors will be computed only if jobvl == lapack.LeftEVCompute,
+// otherwise jobvl must be lapack.LeftEVNone.
+// Right eigenvectors will be computed only if jobvr == lapack.RightEVCompute,
+// otherwise jobvr must be lapack.RightEVNone.
+// For other values of jobvl and jobvr Geev will panic.
+//
+// On return, wr and wi will contain the real and imaginary parts, respectively,
+// of the computed eigenvalues. Complex conjugate pairs of eigenvalues appear
+// consecutively with the eigenvalue having the positive imaginary part first.
+// wr and wi must have length n, and Geev will panic otherwise.
+//
+// work must have length at least lwork and lwork must be at least max(1,4*n) if
+// the left or right eigenvectors are computed, and at least max(1,3*n) if no
+// eigenvectors are computed. For good performance, lwork must generally be
+// larger. On return, the optimal value of lwork will be stored in work[0].
+//
+// If lwork == -1, instead of performing Geev, the function only calculates the
+// optimal value of lwork and stores it into work[0].
+//
+// On return, first will be the index of the first valid eigenvalue.
+// If first == 0, all eigenvalues and eigenvectors have been computed.
+// If first is positive, Geev failed to compute all the eigenvalues, no
+// eigenvectors have been computed and wr[first:] and wi[first:] contain those
+// eigenvalues which have converged.
+func Geev(jobvl lapack.LeftEVJob, jobvr lapack.RightEVJob, a blas64.General, wr, wi []float64, vl, vr blas64.General, work []float64, lwork int) (first int) {
+	n := a.Rows
+	if a.Cols != n {
+		panic("lapack64: matrix not square")
+	}
+	if jobvl == lapack.LeftEVCompute && (vl.Rows != n || vl.Cols != n) {
+		panic("lapack64: bad size of VL")
+	}
+	if jobvr == lapack.RightEVCompute && (vr.Rows != n || vr.Cols != n) {
+		panic("lapack64: bad size of VR")
+	}
+	return lapack64.Dgeev(jobvl, jobvr, n, a.Data, max(1, a.Stride), wr, wi, vl.Data, max(1, vl.Stride), vr.Data, max(1, vr.Stride), work, lwork)
+}
diff --git a/vendor/gonum.org/v1/gonum/mat/cholesky.go b/vendor/gonum.org/v1/gonum/mat/cholesky.go
new file mode 100644
index 0000000000..0f957cdde9
--- /dev/null
+++ b/vendor/gonum.org/v1/gonum/mat/cholesky.go
@@ -0,0 +1,938 @@
+// Copyright ©2013 The Gonum Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+ +package mat + +import ( + "math" + + "gonum.org/v1/gonum/blas" + "gonum.org/v1/gonum/blas/blas64" + "gonum.org/v1/gonum/lapack/lapack64" +) + +const ( + badTriangle = "mat: invalid triangle" + badCholesky = "mat: invalid Cholesky factorization" +) + +var ( + _ Matrix = (*Cholesky)(nil) + _ Symmetric = (*Cholesky)(nil) + + _ Matrix = (*BandCholesky)(nil) + _ Symmetric = (*BandCholesky)(nil) + _ Banded = (*BandCholesky)(nil) + _ SymBanded = (*BandCholesky)(nil) +) + +// Cholesky is a symmetric positive definite matrix represented by its +// Cholesky decomposition. +// +// The decomposition can be constructed using the Factorize method. The +// factorization itself can be extracted using the UTo or LTo methods, and the +// original symmetric matrix can be recovered with ToSym. +// +// Note that this matrix representation is useful for certain operations, in +// particular finding solutions to linear equations. It is very inefficient +// at other operations, in particular At is slow. +// +// Cholesky methods may only be called on a value that has been successfully +// initialized by a call to Factorize that has returned true. Calls to methods +// of an unsuccessful Cholesky factorization will panic. +type Cholesky struct { + // The chol pointer must never be retained as a pointer outside the Cholesky + // struct, either by returning chol outside the struct or by setting it to + // a pointer coming from outside. The same prohibition applies to the data + // slice within chol. + chol *TriDense + cond float64 +} + +// updateCond updates the condition number of the Cholesky decomposition. If +// norm > 0, then that norm is used as the norm of the original matrix A, otherwise +// the norm is estimated from the decomposition. +func (c *Cholesky) updateCond(norm float64) { + n := c.chol.mat.N + work := getFloat64s(3*n, false) + defer putFloat64s(work) + if norm < 0 { + // This is an approximation. By the definition of a norm, + // |AB| <= |A| |B|. + // Since A = Uᵀ*U, we get for the condition number κ that + // κ(A) := |A| |A^-1| = |Uᵀ*U| |A^-1| <= |Uᵀ| |U| |A^-1|, + // so this will overestimate the condition number somewhat. + // The norm of the original factorized matrix cannot be stored + // because of update possibilities. + unorm := lapack64.Lantr(CondNorm, c.chol.mat, work) + lnorm := lapack64.Lantr(CondNormTrans, c.chol.mat, work) + norm = unorm * lnorm + } + sym := c.chol.asSymBlas() + iwork := getInts(n, false) + v := lapack64.Pocon(sym, norm, work, iwork) + putInts(iwork) + c.cond = 1 / v +} + +// Dims returns the dimensions of the matrix. +func (ch *Cholesky) Dims() (r, c int) { + if !ch.valid() { + panic(badCholesky) + } + r, c = ch.chol.Dims() + return r, c +} + +// At returns the element at row i, column j. +func (c *Cholesky) At(i, j int) float64 { + if !c.valid() { + panic(badCholesky) + } + n := c.SymmetricDim() + if uint(i) >= uint(n) { + panic(ErrRowAccess) + } + if uint(j) >= uint(n) { + panic(ErrColAccess) + } + + var val float64 + for k := 0; k <= min(i, j); k++ { + val += c.chol.at(k, i) * c.chol.at(k, j) + } + return val +} + +// T returns the receiver, the transpose of a symmetric matrix. +func (c *Cholesky) T() Matrix { + return c +} + +// SymmetricDim implements the Symmetric interface and returns the number of rows +// in the matrix (this is also the number of columns). +func (c *Cholesky) SymmetricDim() int { + r, _ := c.chol.Dims() + return r +} + +// Cond returns the condition number of the factorized matrix. 
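At the mat level the pieces above compose into the usual factor-then-solve flow. A minimal sketch, assuming the NewSymDense and NewDense constructors from this package; the values are illustrative:

```go
package main

import (
	"fmt"

	"gonum.org/v1/gonum/mat"
)

func main() {
	a := mat.NewSymDense(2, []float64{
		4, 2,
		2, 3,
	})
	var chol mat.Cholesky
	if !chol.Factorize(a) {
		panic("matrix is not positive definite")
	}
	// Solve A*x = (6, 5)ᵀ; a Condition error flags near-singularity.
	b := mat.NewDense(2, 1, []float64{6, 5})
	var x mat.Dense
	if err := chol.SolveTo(&x, b); err != nil {
		fmt.Println("ill-conditioned:", err)
	}
	fmt.Println(mat.Formatted(&x)) // expect (1, 1)ᵀ
}
```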
+func (c *Cholesky) Cond() float64 {
+	if !c.valid() {
+		panic(badCholesky)
+	}
+	return c.cond
+}
+
+// Factorize calculates the Cholesky decomposition of the matrix A and returns
+// whether the matrix is positive definite. If Factorize returns false, the
+// factorization must not be used.
+func (c *Cholesky) Factorize(a Symmetric) (ok bool) {
+	n := a.SymmetricDim()
+	if c.chol == nil {
+		c.chol = NewTriDense(n, Upper, nil)
+	} else {
+		c.chol.Reset()
+		c.chol.reuseAsNonZeroed(n, Upper)
+	}
+	copySymIntoTriangle(c.chol, a)
+
+	sym := c.chol.asSymBlas()
+	work := getFloat64s(c.chol.mat.N, false)
+	norm := lapack64.Lansy(CondNorm, sym, work)
+	putFloat64s(work)
+	_, ok = lapack64.Potrf(sym)
+	if ok {
+		c.updateCond(norm)
+	} else {
+		c.Reset()
+	}
+	return ok
+}
+
+// Reset resets the factorization so that it can be reused as the receiver of a
+// dimensionally restricted operation.
+func (c *Cholesky) Reset() {
+	if c.chol != nil {
+		c.chol.Reset()
+	}
+	c.cond = math.Inf(1)
+}
+
+// IsEmpty returns whether the receiver is empty. Empty matrices can be the
+// receiver for size-restricted operations. The receiver can be emptied using
+// Reset.
+func (c *Cholesky) IsEmpty() bool {
+	return c.chol == nil || c.chol.IsEmpty()
+}
+
+// SetFromU sets the Cholesky decomposition from the given triangular matrix.
+// SetFromU panics if t is not upper triangular. If the receiver is empty it
+// is resized to be n×n, the size of t. If the receiver is non-empty, SetFromU
+// panics if it is not of size n×n. Note that t is copied into, not stored
+// inside, the receiver.
+func (c *Cholesky) SetFromU(t Triangular) {
+	n, kind := t.Triangle()
+	if kind != Upper {
+		panic("cholesky: matrix must be upper triangular")
+	}
+	if c.chol == nil {
+		c.chol = NewTriDense(n, Upper, nil)
+	} else {
+		c.chol.reuseAsNonZeroed(n, Upper)
+	}
+	c.chol.Copy(t)
+	c.updateCond(-1)
+}
+
+// Clone makes a copy of the input Cholesky into the receiver, overwriting the
+// previous value of the receiver. Clone does not place any restrictions on receiver
+// shape. Clone panics if the input Cholesky is not the result of a valid decomposition.
+func (c *Cholesky) Clone(chol *Cholesky) {
+	if !chol.valid() {
+		panic(badCholesky)
+	}
+	n := chol.SymmetricDim()
+	if c.chol == nil {
+		c.chol = NewTriDense(n, Upper, nil)
+	} else {
+		c.chol = NewTriDense(n, Upper, use(c.chol.mat.Data, n*n))
+	}
+	c.chol.Copy(chol.chol)
+	c.cond = chol.cond
+}
+
+// Det returns the determinant of the matrix that has been factorized.
+func (c *Cholesky) Det() float64 {
+	if !c.valid() {
+		panic(badCholesky)
+	}
+	return math.Exp(c.LogDet())
+}
+
+// LogDet returns the log of the determinant of the matrix that has been factorized.
+func (c *Cholesky) LogDet() float64 {
+	if !c.valid() {
+		panic(badCholesky)
+	}
+	var det float64
+	for i := 0; i < c.chol.mat.N; i++ {
+		det += 2 * math.Log(c.chol.mat.Data[i*c.chol.mat.Stride+i])
+	}
+	return det
+}
+
+// SolveTo finds the matrix X that solves A * X = B where A is represented
+// by the Cholesky decomposition. The result is stored in-place into dst.
+// If the Cholesky decomposition is singular or near-singular a Condition error
+// is returned. See the documentation for Condition for more information.
+func (c *Cholesky) SolveTo(dst *Dense, b Matrix) error { + if !c.valid() { + panic(badCholesky) + } + n := c.chol.mat.N + bm, bn := b.Dims() + if n != bm { + panic(ErrShape) + } + + dst.reuseAsNonZeroed(bm, bn) + if b != dst { + dst.Copy(b) + } + lapack64.Potrs(c.chol.mat, dst.mat) + if c.cond > ConditionTolerance { + return Condition(c.cond) + } + return nil +} + +// SolveCholTo finds the matrix X that solves A * X = B where A and B are represented +// by their Cholesky decompositions a and b. The result is stored in-place into +// dst. +// If the Cholesky decomposition is singular or near-singular a Condition error +// is returned. See the documentation for Condition for more information. +func (a *Cholesky) SolveCholTo(dst *Dense, b *Cholesky) error { + if !a.valid() || !b.valid() { + panic(badCholesky) + } + bn := b.chol.mat.N + if a.chol.mat.N != bn { + panic(ErrShape) + } + + dst.reuseAsZeroed(bn, bn) + dst.Copy(b.chol.T()) + blas64.Trsm(blas.Left, blas.Trans, 1, a.chol.mat, dst.mat) + blas64.Trsm(blas.Left, blas.NoTrans, 1, a.chol.mat, dst.mat) + blas64.Trmm(blas.Right, blas.NoTrans, 1, b.chol.mat, dst.mat) + if a.cond > ConditionTolerance { + return Condition(a.cond) + } + return nil +} + +// SolveVecTo finds the vector x that solves A * x = b where A is represented +// by the Cholesky decomposition. The result is stored in-place into +// dst. +// If the Cholesky decomposition is singular or near-singular a Condition error +// is returned. See the documentation for Condition for more information. +func (c *Cholesky) SolveVecTo(dst *VecDense, b Vector) error { + if !c.valid() { + panic(badCholesky) + } + n := c.chol.mat.N + if br, bc := b.Dims(); br != n || bc != 1 { + panic(ErrShape) + } + switch rv := b.(type) { + default: + dst.reuseAsNonZeroed(n) + return c.SolveTo(dst.asDense(), b) + case RawVectorer: + bmat := rv.RawVector() + if dst != b { + dst.checkOverlap(bmat) + } + dst.reuseAsNonZeroed(n) + if dst != b { + dst.CopyVec(b) + } + lapack64.Potrs(c.chol.mat, dst.asGeneral()) + if c.cond > ConditionTolerance { + return Condition(c.cond) + } + return nil + } +} + +// RawU returns the Triangular matrix used to store the Cholesky decomposition of +// the original matrix A. The returned matrix should not be modified. If it is +// modified, the decomposition is invalid and should not be used. +func (c *Cholesky) RawU() Triangular { + return c.chol +} + +// UTo stores into dst the n×n upper triangular matrix U from a Cholesky +// decomposition +// +// A = Uᵀ * U. +// +// If dst is empty, it is resized to be an n×n upper triangular matrix. When dst +// is non-empty, UTo panics if dst is not n×n or not Upper. UTo will also panic +// if the receiver does not contain a successful factorization. +func (c *Cholesky) UTo(dst *TriDense) { + if !c.valid() { + panic(badCholesky) + } + n := c.chol.mat.N + if dst.IsEmpty() { + dst.ReuseAsTri(n, Upper) + } else { + n2, kind := dst.Triangle() + if n != n2 { + panic(ErrShape) + } + if kind != Upper { + panic(ErrTriangle) + } + } + dst.Copy(c.chol) +} + +// LTo stores into dst the n×n lower triangular matrix L from a Cholesky +// decomposition +// +// A = L * Lᵀ. +// +// If dst is empty, it is resized to be an n×n lower triangular matrix. When dst +// is non-empty, LTo panics if dst is not n×n or not Lower. LTo will also panic +// if the receiver does not contain a successful factorization. 
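Extracting both factors is symmetric in shape (LTo follows below). A short illustrative sketch:

```go
package main

import (
	"fmt"

	"gonum.org/v1/gonum/mat"
)

func main() {
	a := mat.NewSymDense(2, []float64{
		4, 2,
		2, 3,
	})
	var chol mat.Cholesky
	if !chol.Factorize(a) {
		panic("matrix is not positive definite")
	}
	var u, l mat.TriDense
	chol.UTo(&u) // A = Uᵀ*U
	chol.LTo(&l) // A = L*Lᵀ, i.e. L = Uᵀ
	fmt.Println(mat.Formatted(&u))
	fmt.Println(mat.Formatted(&l))
}
```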
+func (c *Cholesky) LTo(dst *TriDense) {
+	if !c.valid() {
+		panic(badCholesky)
+	}
+	n := c.chol.mat.N
+	if dst.IsEmpty() {
+		dst.ReuseAsTri(n, Lower)
+	} else {
+		n2, kind := dst.Triangle()
+		if n != n2 {
+			panic(ErrShape)
+		}
+		if kind != Lower {
+			panic(ErrTriangle)
+		}
+	}
+	dst.Copy(c.chol.TTri())
+}
+
+// ToSym reconstructs the original positive definite matrix from its
+// Cholesky decomposition, storing the result into dst. If dst is
+// empty it is resized to be n×n. If dst is non-empty, ToSym panics
+// if dst is not of size n×n. ToSym will also panic if the receiver
+// does not contain a successful factorization.
+func (c *Cholesky) ToSym(dst *SymDense) {
+	if !c.valid() {
+		panic(badCholesky)
+	}
+	n := c.chol.mat.N
+	if dst.IsEmpty() {
+		dst.ReuseAsSym(n)
+	} else {
+		n2 := dst.SymmetricDim()
+		if n != n2 {
+			panic(ErrShape)
+		}
+	}
+	// Create a TriDense representing the Cholesky factor U with dst's
+	// backing slice.
+	// Operations on u are reflected in dst.
+	u := &TriDense{
+		mat: blas64.Triangular{
+			Uplo:   blas.Upper,
+			Diag:   blas.NonUnit,
+			N:      n,
+			Data:   dst.mat.Data,
+			Stride: dst.mat.Stride,
+		},
+		cap: n,
+	}
+	u.Copy(c.chol)
+	// Compute the product Uᵀ*U using the algorithm from LAPACK/TESTING/LIN/dpot01.f
+	a := u.mat.Data
+	lda := u.mat.Stride
+	bi := blas64.Implementation()
+	for k := n - 1; k >= 0; k-- {
+		a[k*lda+k] = bi.Ddot(k+1, a[k:], lda, a[k:], lda)
+		if k > 0 {
+			bi.Dtrmv(blas.Upper, blas.Trans, blas.NonUnit, k, a, lda, a[k:], lda)
+		}
+	}
+}
+
+// InverseTo computes the inverse of the matrix represented by its Cholesky
+// factorization and stores the result into dst. If the factorized
+// matrix is ill-conditioned, a Condition error will be returned.
+// Note that matrix inversion is numerically unstable, and should generally be
+// avoided where possible, for example by using the Solve routines.
+func (c *Cholesky) InverseTo(dst *SymDense) error {
+	if !c.valid() {
+		panic(badCholesky)
+	}
+	dst.reuseAsNonZeroed(c.chol.mat.N)
+	// Create a TriDense representing the Cholesky factor U with the backing
+	// slice from dst.
+	// Operations on u are reflected in dst.
+	u := &TriDense{
+		mat: blas64.Triangular{
+			Uplo:   blas.Upper,
+			Diag:   blas.NonUnit,
+			N:      dst.mat.N,
+			Data:   dst.mat.Data,
+			Stride: dst.mat.Stride,
+		},
+		cap: dst.mat.N,
+	}
+	u.Copy(c.chol)
+
+	_, ok := lapack64.Potri(u.mat)
+	if !ok {
+		return Condition(math.Inf(1))
+	}
+	if c.cond > ConditionTolerance {
+		return Condition(c.cond)
+	}
+	return nil
+}
+
+// Scale multiplies the original matrix A by a positive constant using
+// its Cholesky decomposition, storing the result in-place into the receiver.
+// That is, if the original Cholesky factorization is
+//
+//	Uᵀ * U = A
+//
+// the updated factorization is
+//
+//	U'ᵀ * U' = f A = A'
+//
+// Scale panics if the constant is non-positive, or if the receiver is non-empty
+// and is of a different size from the input.
+func (c *Cholesky) Scale(f float64, orig *Cholesky) {
+	if !orig.valid() {
+		panic(badCholesky)
+	}
+	if f <= 0 {
+		panic("cholesky: scaling by a non-positive constant")
+	}
+	n := orig.SymmetricDim()
+	if c.chol == nil {
+		c.chol = NewTriDense(n, Upper, nil)
+	} else if c.chol.mat.N != n {
+		panic(ErrShape)
+	}
+	c.chol.ScaleTri(math.Sqrt(f), orig.chol)
+	c.cond = orig.cond // Scaling by a positive constant does not change the condition number.
+}
+
+// ExtendVecSym computes the Cholesky decomposition of the original matrix A,
+// whose Cholesky decomposition is in a, extended by the n×1 vector v according to
+//
+//	[A  w]
+//	[w' k]
+//
+// where k = v[n-1] and w = v[:n-1]. The result is stored into the receiver.
+// In order for the updated matrix to be positive definite, it must be the case
+// that k > w' A^-1 w. If this condition does not hold then ExtendVecSym will
+// return false and the receiver will not be updated.
+//
+// ExtendVecSym will panic if v.Len() != a.SymmetricDim()+1 or if a does not contain
+// a valid decomposition.
+func (c *Cholesky) ExtendVecSym(a *Cholesky, v Vector) (ok bool) {
+	n := a.SymmetricDim()
+
+	if v.Len() != n+1 {
+		panic(badSliceLength)
+	}
+	if !a.valid() {
+		panic(badCholesky)
+	}
+
+	// The algorithm is commented here, but see also
+	//  https://math.stackexchange.com/questions/955874/cholesky-factor-when-adding-a-row-and-column-to-already-factorized-matrix
+	// We have A and want to compute the Cholesky of
+	//  [A  w]
+	//  [w' k]
+	// We want
+	//  [U  c]
+	//  [0  d]
+	// to be the updated Cholesky, and so it must be that
+	//  [A  w] = [U' 0] [U  c]
+	//  [w' k]   [c' d] [0  d]
+	// Thus, we need
+	//  1) A = U'U (true by the original decomposition being valid),
+	//  2) U' * c = w  =>  c = U'^-1 w
+	//  3) c'*c + d'*d = k  =>  d = sqrt(k-c'*c)
+
+	// First, compute c = U'^-1 w
+	w := NewVecDense(n, nil)
+	w.CopyVec(v)
+	k := v.At(n, 0)
+
+	var t VecDense
+	_ = t.SolveVec(a.chol.T(), w)
+
+	dot := Dot(&t, &t)
+	if dot >= k {
+		return false
+	}
+	d := math.Sqrt(k - dot)
+
+	newU := NewTriDense(n+1, Upper, nil)
+	newU.Copy(a.chol)
+	for i := 0; i < n; i++ {
+		newU.SetTri(i, n, t.At(i, 0))
+	}
+	newU.SetTri(n, n, d)
+	c.chol = newU
+	c.updateCond(-1)
+	return true
+}
+
+// SymRankOne performs a rank-1 update of the original matrix A and refactorizes
+// its Cholesky factorization, storing the result into the receiver. That is, if
+// in the original Cholesky factorization
+//
+//	Uᵀ * U = A,
+//
+// in the updated factorization
+//
+//	U'ᵀ * U' = A + alpha * x * xᵀ = A'.
+//
+// Note that when alpha is negative, the updating problem may be ill-conditioned
+// and the results may be inaccurate, or the updated matrix A' may not be
+// positive definite and not have a Cholesky factorization. SymRankOne returns
+// whether the updated matrix A' is positive definite. If the update fails
+// the receiver is left unchanged.
+//
+// SymRankOne updates a Cholesky factorization in O(n²) time. The Cholesky
+// factorization computation from scratch is O(n³).
+func (c *Cholesky) SymRankOne(orig *Cholesky, alpha float64, x Vector) (ok bool) {
+	if !orig.valid() {
+		panic(badCholesky)
+	}
+	n := orig.SymmetricDim()
+	if r, c := x.Dims(); r != n || c != 1 {
+		panic(ErrShape)
+	}
+	if orig != c {
+		if c.chol == nil {
+			c.chol = NewTriDense(n, Upper, nil)
+		} else if c.chol.mat.N != n {
+			panic(ErrShape)
+		}
+		c.chol.Copy(orig.chol)
+	}
+
+	if alpha == 0 {
+		return true
+	}
+
+	// Algorithms for updating and downdating the Cholesky factorization are
+	// described, for example, in
+	// - J. J. Dongarra, J. R. Bunch, C. B. Moler, G. W. Stewart: LINPACK
+	//   Users' Guide. SIAM (1979), pages 10.10--10.14
+	// or
+	// - P. E. Gill, G. H. Golub, W. Murray, and M. A. Saunders: Methods for
+	//   modifying matrix factorizations. Mathematics of Computation 28(126)
+	//   (1974), Method C3 on page 521
+	//
+	// The implementation is based on LINPACK code
+	//  http://www.netlib.org/linpack/dchud.f
+	//  http://www.netlib.org/linpack/dchdd.f
+	// and
+	//  https://icl.cs.utk.edu/lapack-forum/viewtopic.php?f=2&t=2646
+	//
+	// According to http://icl.cs.utk.edu/lapack-forum/archives/lapack/msg00301.html
+	// LINPACK is released under BSD license.
+	//
+	// See also:
+	// - M. A. Saunders: Large-scale Linear Programming Using the Cholesky
+	//   Factorization. Technical Report Stanford University (1972)
+	//   http://i.stanford.edu/pub/cstr/reports/cs/tr/72/252/CS-TR-72-252.pdf
+	// - Matthias Seeger: Low rank updates for the Cholesky decomposition.
+	//   EPFL Technical Report 161468 (2004)
+	//   http://infoscience.epfl.ch/record/161468
+
+	work := getFloat64s(n, false)
+	defer putFloat64s(work)
+	var xmat blas64.Vector
+	if rv, ok := x.(RawVectorer); ok {
+		xmat = rv.RawVector()
+	} else {
+		var tmp VecDense
+		tmp.CopyVec(x)
+		xmat = tmp.RawVector()
+	}
+	blas64.Copy(xmat, blas64.Vector{N: n, Data: work, Inc: 1})
+
+	if alpha > 0 {
+		// Compute rank-1 update.
+		if alpha != 1 {
+			blas64.Scal(math.Sqrt(alpha), blas64.Vector{N: n, Data: work, Inc: 1})
+		}
+		umat := c.chol.mat
+		stride := umat.Stride
+		for i := 0; i < n; i++ {
+			// Compute parameters of the Givens matrix that zeroes
+			// the i-th element of x.
+			c, s, r, _ := blas64.Rotg(umat.Data[i*stride+i], work[i])
+			if r < 0 {
+				// Multiply by -1 to have positive diagonal
+				// elements.
+				r *= -1
+				c *= -1
+				s *= -1
+			}
+			umat.Data[i*stride+i] = r
+			if i < n-1 {
+				// Multiply the extended factorization matrix by
+				// the Givens matrix from the left. Only
+				// the i-th row and x are modified.
+				blas64.Rot(
+					blas64.Vector{N: n - i - 1, Data: umat.Data[i*stride+i+1 : i*stride+n], Inc: 1},
+					blas64.Vector{N: n - i - 1, Data: work[i+1 : n], Inc: 1},
+					c, s)
+			}
+		}
+		c.updateCond(-1)
+		return true
+	}
+
+	// Compute rank-1 downdate.
+	alpha = math.Sqrt(-alpha)
+	if alpha != 1 {
+		blas64.Scal(alpha, blas64.Vector{N: n, Data: work, Inc: 1})
+	}
+	// Solve Uᵀ * p = x storing the result into work.
+	ok = lapack64.Trtrs(blas.Trans, c.chol.RawTriangular(), blas64.General{
+		Rows:   n,
+		Cols:   1,
+		Stride: 1,
+		Data:   work,
+	})
+	if !ok {
+		// The original matrix is singular. Should not happen, because
+		// the factorization is valid.
+		panic(badCholesky)
+	}
+	norm := blas64.Nrm2(blas64.Vector{N: n, Data: work, Inc: 1})
+	if norm >= 1 {
+		// The updated matrix is not positive definite.
+		return false
+	}
+	norm = math.Sqrt((1 + norm) * (1 - norm))
+	cos := getFloat64s(n, false)
+	defer putFloat64s(cos)
+	sin := getFloat64s(n, false)
+	defer putFloat64s(sin)
+	for i := n - 1; i >= 0; i-- {
+		// Compute parameters of Givens matrices that zero elements of p
+		// backwards.
+		cos[i], sin[i], norm, _ = blas64.Rotg(norm, work[i])
+		if norm < 0 {
+			norm *= -1
+			cos[i] *= -1
+			sin[i] *= -1
+		}
+	}
+	workMat := getTriDenseWorkspace(c.chol.mat.N, c.chol.triKind(), false)
+	defer putTriWorkspace(workMat)
+	workMat.Copy(c.chol)
+	umat := workMat.mat
+	stride := workMat.mat.Stride
+	for i := n - 1; i >= 0; i-- {
+		work[i] = 0
+		// Apply Givens matrices to U.
+		blas64.Rot(
+			blas64.Vector{N: n - i, Data: work[i:n], Inc: 1},
+			blas64.Vector{N: n - i, Data: umat.Data[i*stride+i : i*stride+n], Inc: 1},
+			cos[i], sin[i])
+		if umat.Data[i*stride+i] == 0 {
+			// The matrix is singular (may rarely happen due to
+			// floating-point effects?).
+ ok = false + } else if umat.Data[i*stride+i] < 0 { + // Diagonal elements should be positive. If it happens + // that on the i-th row the diagonal is negative, + // multiply U from the left by an identity matrix that + // has -1 on the i-th row. + blas64.Scal(-1, blas64.Vector{N: n - i, Data: umat.Data[i*stride+i : i*stride+n], Inc: 1}) + } + } + if ok { + c.chol.Copy(workMat) + c.updateCond(-1) + } + return ok +} + +func (c *Cholesky) valid() bool { + return c.chol != nil && !c.chol.IsEmpty() +} + +// BandCholesky is a symmetric positive-definite band matrix represented by its +// Cholesky decomposition. +// +// Note that this matrix representation is useful for certain operations, in +// particular finding solutions to linear equations. It is very inefficient at +// other operations, in particular At is slow. +// +// BandCholesky methods may only be called on a value that has been successfully +// initialized by a call to Factorize that has returned true. Calls to methods +// of an unsuccessful Cholesky factorization will panic. +type BandCholesky struct { + // The chol pointer must never be retained as a pointer outside the Cholesky + // struct, either by returning chol outside the struct or by setting it to + // a pointer coming from outside. The same prohibition applies to the data + // slice within chol. + chol *TriBandDense + cond float64 +} + +// Factorize calculates the Cholesky decomposition of the matrix A and returns +// whether the matrix is positive definite. If Factorize returns false, the +// factorization must not be used. +func (ch *BandCholesky) Factorize(a SymBanded) (ok bool) { + n, k := a.SymBand() + if ch.chol == nil { + ch.chol = NewTriBandDense(n, k, Upper, nil) + } else { + ch.chol.Reset() + ch.chol.ReuseAsTriBand(n, k, Upper) + } + copySymBandIntoTriBand(ch.chol, a) + cSym := blas64.SymmetricBand{ + Uplo: blas.Upper, + N: n, + K: k, + Data: ch.chol.RawTriBand().Data, + Stride: ch.chol.RawTriBand().Stride, + } + _, ok = lapack64.Pbtrf(cSym) + if !ok { + ch.Reset() + return false + } + work := getFloat64s(3*n, false) + iwork := getInts(n, false) + aNorm := lapack64.Lansb(CondNorm, cSym, work) + ch.cond = 1 / lapack64.Pbcon(cSym, aNorm, work, iwork) + putInts(iwork) + putFloat64s(work) + return true +} + +// SolveTo finds the matrix X that solves A * X = B where A is represented by +// the Cholesky decomposition. The result is stored in-place into dst. +// If the Cholesky decomposition is singular or near-singular a Condition error +// is returned. See the documentation for Condition for more information. +func (ch *BandCholesky) SolveTo(dst *Dense, b Matrix) error { + if !ch.valid() { + panic(badCholesky) + } + br, bc := b.Dims() + if br != ch.chol.mat.N { + panic(ErrShape) + } + dst.reuseAsNonZeroed(br, bc) + if b != dst { + dst.Copy(b) + } + lapack64.Pbtrs(ch.chol.mat, dst.mat) + if ch.cond > ConditionTolerance { + return Condition(ch.cond) + } + return nil +} + +// SolveVecTo finds the vector x that solves A * x = b where A is represented by +// the Cholesky decomposition. The result is stored in-place into dst. +// If the Cholesky decomposition is singular or near-singular a Condition error +// is returned. See the documentation for Condition for more information. 
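+//
+// A minimal usage sketch, assuming a BandCholesky ch that has been
+// successfully factorized and a vector b of matching length:
+//
+//	var x VecDense
+//	if err := ch.SolveVecTo(&x, b); err != nil {
+//		// A Condition error still leaves the computed solution in x;
+//		// it only signals that the result may be inaccurate.
+//	}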
+func (ch *BandCholesky) SolveVecTo(dst *VecDense, b Vector) error { + if !ch.valid() { + panic(badCholesky) + } + n := ch.chol.mat.N + if br, bc := b.Dims(); br != n || bc != 1 { + panic(ErrShape) + } + if b, ok := b.(RawVectorer); ok && dst != b { + dst.checkOverlap(b.RawVector()) + } + dst.reuseAsNonZeroed(n) + if dst != b { + dst.CopyVec(b) + } + lapack64.Pbtrs(ch.chol.mat, dst.asGeneral()) + if ch.cond > ConditionTolerance { + return Condition(ch.cond) + } + return nil +} + +// Cond returns the condition number of the factorized matrix. +func (ch *BandCholesky) Cond() float64 { + if !ch.valid() { + panic(badCholesky) + } + return ch.cond +} + +// Reset resets the factorization so that it can be reused as the receiver of +// a dimensionally restricted operation. +func (ch *BandCholesky) Reset() { + if ch.chol != nil { + ch.chol.Reset() + } + ch.cond = math.Inf(1) +} + +// Dims returns the dimensions of the matrix. +func (ch *BandCholesky) Dims() (r, c int) { + if !ch.valid() { + panic(badCholesky) + } + r, c = ch.chol.Dims() + return r, c +} + +// At returns the element at row i, column j. +func (ch *BandCholesky) At(i, j int) float64 { + if !ch.valid() { + panic(badCholesky) + } + n, k, _ := ch.chol.TriBand() + if uint(i) >= uint(n) { + panic(ErrRowAccess) + } + if uint(j) >= uint(n) { + panic(ErrColAccess) + } + + if i > j { + i, j = j, i + } + if j-i > k { + return 0 + } + var aij float64 + for k := max(0, j-k); k <= i; k++ { + aij += ch.chol.at(k, i) * ch.chol.at(k, j) + } + return aij +} + +// T returns the receiver, the transpose of a symmetric matrix. +func (ch *BandCholesky) T() Matrix { + return ch +} + +// TBand returns the receiver, the transpose of a symmetric band matrix. +func (ch *BandCholesky) TBand() Banded { + return ch +} + +// SymmetricDim implements the Symmetric interface and returns the number of rows +// in the matrix (this is also the number of columns). +func (ch *BandCholesky) SymmetricDim() int { + n, _ := ch.chol.Triangle() + return n +} + +// Bandwidth returns the lower and upper bandwidth values for the matrix. +// The total bandwidth of the matrix is kl+ku+1. +func (ch *BandCholesky) Bandwidth() (kl, ku int) { + _, k, _ := ch.chol.TriBand() + return k, k +} + +// SymBand returns the number of rows/columns in the matrix, and the size of the +// bandwidth. The total bandwidth of the matrix is 2*k+1. +func (ch *BandCholesky) SymBand() (n, k int) { + n, k, _ = ch.chol.TriBand() + return n, k +} + +// IsEmpty returns whether the receiver is empty. Empty matrices can be the +// receiver for dimensionally restricted operations. The receiver can be emptied +// using Reset. +func (ch *BandCholesky) IsEmpty() bool { + return ch == nil || ch.chol.IsEmpty() +} + +// Det returns the determinant of the matrix that has been factorized. +func (ch *BandCholesky) Det() float64 { + if !ch.valid() { + panic(badCholesky) + } + return math.Exp(ch.LogDet()) +} + +// LogDet returns the log of the determinant of the matrix that has been factorized. 
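+//
+// Since A = Uᵀ U, det(A) is the squared product of the diagonal of U, so the
+// value is accumulated as 2·Σ log(u_ii) without ever forming det(A) itself.
+// Det simply exponentiates this value, so LogDet is the safer choice for
+// large matrices whose determinant would overflow.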
+func (ch *BandCholesky) LogDet() float64 { + if !ch.valid() { + panic(badCholesky) + } + var det float64 + for i := 0; i < ch.chol.mat.N; i++ { + det += 2 * math.Log(ch.chol.mat.Data[i*ch.chol.mat.Stride]) + } + return det +} + +func (ch *BandCholesky) valid() bool { + return ch.chol != nil && !ch.chol.IsEmpty() +} diff --git a/vendor/google.golang.org/grpc/internal/transport/http2_server.go b/vendor/google.golang.org/grpc/internal/transport/http2_server.go new file mode 100644 index 0000000000..8d3a353c1d --- /dev/null +++ b/vendor/google.golang.org/grpc/internal/transport/http2_server.go @@ -0,0 +1,1469 @@ +/* + * + * Copyright 2014 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package transport + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "math" + "net" + "net/http" + "strconv" + "sync" + "sync/atomic" + "time" + + "github.com/golang/protobuf/proto" + "golang.org/x/net/http2" + "golang.org/x/net/http2/hpack" + "google.golang.org/grpc/internal/grpclog" + "google.golang.org/grpc/internal/grpcutil" + "google.golang.org/grpc/internal/pretty" + "google.golang.org/grpc/internal/syscall" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/internal/channelz" + "google.golang.org/grpc/internal/grpcrand" + "google.golang.org/grpc/internal/grpcsync" + "google.golang.org/grpc/keepalive" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/peer" + "google.golang.org/grpc/stats" + "google.golang.org/grpc/status" + "google.golang.org/grpc/tap" +) + +var ( + // ErrIllegalHeaderWrite indicates that setting header is illegal because of + // the stream's state. + ErrIllegalHeaderWrite = status.Error(codes.Internal, "transport: SendHeader called multiple times") + // ErrHeaderListSizeLimitViolation indicates that the header list size is larger + // than the limit set by peer. + ErrHeaderListSizeLimitViolation = status.Error(codes.Internal, "transport: trying to send header list size larger than the limit set by peer") +) + +// serverConnectionCounter counts the number of connections a server has seen +// (equal to the number of http2Servers created). Must be accessed atomically. +var serverConnectionCounter uint64 + +// http2Server implements the ServerTransport interface with HTTP2. +type http2Server struct { + lastRead int64 // Keep this field 64-bit aligned. Accessed atomically. + ctx context.Context + done chan struct{} + conn net.Conn + loopy *loopyWriter + readerDone chan struct{} // sync point to enable testing. + writerDone chan struct{} // sync point to enable testing. + remoteAddr net.Addr + localAddr net.Addr + authInfo credentials.AuthInfo // auth info about the connection + inTapHandle tap.ServerInHandle + framer *framer + // The max number of concurrent streams. + maxStreams uint32 + // controlBuf delivers all the control related tasks (e.g., window + // updates, reset streams, and various settings) to the controller. 
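+ // Every outgoing frame is queued here and drained by the loopy writer
+ // goroutine that NewServerTransport starts.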
+ controlBuf *controlBuffer + fc *trInFlow + stats []stats.Handler + // Keepalive and max-age parameters for the server. + kp keepalive.ServerParameters + // Keepalive enforcement policy. + kep keepalive.EnforcementPolicy + // The time instance last ping was received. + lastPingAt time.Time + // Number of times the client has violated keepalive ping policy so far. + pingStrikes uint8 + // Flag to signify that number of ping strikes should be reset to 0. + // This is set whenever data or header frames are sent. + // 1 means yes. + resetPingStrikes uint32 // Accessed atomically. + initialWindowSize int32 + bdpEst *bdpEstimator + maxSendHeaderListSize *uint32 + + mu sync.Mutex // guard the following + + // drainEvent is initialized when Drain() is called the first time. After + // which the server writes out the first GoAway(with ID 2^31-1) frame. Then + // an independent goroutine will be launched to later send the second + // GoAway. During this time we don't want to write another first GoAway(with + // ID 2^31 -1) frame. Thus call to Drain() will be a no-op if drainEvent is + // already initialized since draining is already underway. + drainEvent *grpcsync.Event + state transportState + activeStreams map[uint32]*Stream + // idle is the time instant when the connection went idle. + // This is either the beginning of the connection or when the number of + // RPCs go down to 0. + // When the connection is busy, this value is set to 0. + idle time.Time + + // Fields below are for channelz metric collection. + channelzID *channelz.Identifier + czData *channelzData + bufferPool *bufferPool + + connectionID uint64 + + // maxStreamMu guards the maximum stream ID + // This lock may not be taken if mu is already held. + maxStreamMu sync.Mutex + maxStreamID uint32 // max stream ID ever seen + + logger *grpclog.PrefixLogger +} + +// NewServerTransport creates a http2 transport with conn and configuration +// options from config. +// +// It returns a non-nil transport and a nil error on success. On failure, it +// returns a nil transport and a non-nil error. For a special case where the +// underlying conn gets closed before the client preface could be read, it +// returns a nil transport and a nil error. +func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport, err error) { + var authInfo credentials.AuthInfo + rawConn := conn + if config.Credentials != nil { + var err error + conn, authInfo, err = config.Credentials.ServerHandshake(rawConn) + if err != nil { + // ErrConnDispatched means that the connection was dispatched away + // from gRPC; those connections should be left open. io.EOF means + // the connection was closed before handshaking completed, which can + // happen naturally from probers. Return these errors directly. + if err == credentials.ErrConnDispatched || err == io.EOF { + return nil, err + } + return nil, connectionErrorf(false, err, "ServerHandshake(%q) failed: %v", rawConn.RemoteAddr(), err) + } + } + writeBufSize := config.WriteBufferSize + readBufSize := config.ReadBufferSize + maxHeaderListSize := defaultServerMaxHeaderListSize + if config.MaxHeaderListSize != nil { + maxHeaderListSize = *config.MaxHeaderListSize + } + framer := newFramer(conn, writeBufSize, readBufSize, config.SharedWriteBuffer, maxHeaderListSize) + // Send initial settings as connection preface to client. 
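+ // The client's own preface is validated later in this function, after
+ // the settings (and any window update) have been flushed.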
+ isettings := []http2.Setting{{ + ID: http2.SettingMaxFrameSize, + Val: http2MaxFrameLen, + }} + // TODO(zhaoq): Have a better way to signal "no limit" because 0 is + // permitted in the HTTP2 spec. + maxStreams := config.MaxStreams + if maxStreams == 0 { + maxStreams = math.MaxUint32 + } else { + isettings = append(isettings, http2.Setting{ + ID: http2.SettingMaxConcurrentStreams, + Val: maxStreams, + }) + } + dynamicWindow := true + iwz := int32(initialWindowSize) + if config.InitialWindowSize >= defaultWindowSize { + iwz = config.InitialWindowSize + dynamicWindow = false + } + icwz := int32(initialWindowSize) + if config.InitialConnWindowSize >= defaultWindowSize { + icwz = config.InitialConnWindowSize + dynamicWindow = false + } + if iwz != defaultWindowSize { + isettings = append(isettings, http2.Setting{ + ID: http2.SettingInitialWindowSize, + Val: uint32(iwz)}) + } + if config.MaxHeaderListSize != nil { + isettings = append(isettings, http2.Setting{ + ID: http2.SettingMaxHeaderListSize, + Val: *config.MaxHeaderListSize, + }) + } + if config.HeaderTableSize != nil { + isettings = append(isettings, http2.Setting{ + ID: http2.SettingHeaderTableSize, + Val: *config.HeaderTableSize, + }) + } + if err := framer.fr.WriteSettings(isettings...); err != nil { + return nil, connectionErrorf(false, err, "transport: %v", err) + } + // Adjust the connection flow control window if needed. + if delta := uint32(icwz - defaultWindowSize); delta > 0 { + if err := framer.fr.WriteWindowUpdate(0, delta); err != nil { + return nil, connectionErrorf(false, err, "transport: %v", err) + } + } + kp := config.KeepaliveParams + if kp.MaxConnectionIdle == 0 { + kp.MaxConnectionIdle = defaultMaxConnectionIdle + } + if kp.MaxConnectionAge == 0 { + kp.MaxConnectionAge = defaultMaxConnectionAge + } + // Add a jitter to MaxConnectionAge. + kp.MaxConnectionAge += getJitter(kp.MaxConnectionAge) + if kp.MaxConnectionAgeGrace == 0 { + kp.MaxConnectionAgeGrace = defaultMaxConnectionAgeGrace + } + if kp.Time == 0 { + kp.Time = defaultServerKeepaliveTime + } + if kp.Timeout == 0 { + kp.Timeout = defaultServerKeepaliveTimeout + } + if kp.Time != infinity { + if err = syscall.SetTCPUserTimeout(rawConn, kp.Timeout); err != nil { + return nil, connectionErrorf(false, err, "transport: failed to set TCP_USER_TIMEOUT: %v", err) + } + } + kep := config.KeepalivePolicy + if kep.MinTime == 0 { + kep.MinTime = defaultKeepalivePolicyMinTime + } + + done := make(chan struct{}) + t := &http2Server{ + ctx: setConnection(context.Background(), rawConn), + done: done, + conn: conn, + remoteAddr: conn.RemoteAddr(), + localAddr: conn.LocalAddr(), + authInfo: authInfo, + framer: framer, + readerDone: make(chan struct{}), + writerDone: make(chan struct{}), + maxStreams: maxStreams, + inTapHandle: config.InTapHandle, + fc: &trInFlow{limit: uint32(icwz)}, + state: reachable, + activeStreams: make(map[uint32]*Stream), + stats: config.StatsHandlers, + kp: kp, + idle: time.Now(), + kep: kep, + initialWindowSize: iwz, + czData: new(channelzData), + bufferPool: newBufferPool(), + } + t.logger = prefixLoggerForServerTransport(t) + // Add peer information to the http2server context. 
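+ // This makes the remote address and auth info available to handlers
+ // through peer.FromContext.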
+ t.ctx = peer.NewContext(t.ctx, t.getPeer()) + + t.controlBuf = newControlBuffer(t.done) + if dynamicWindow { + t.bdpEst = &bdpEstimator{ + bdp: initialWindowSize, + updateFlowControl: t.updateFlowControl, + } + } + for _, sh := range t.stats { + t.ctx = sh.TagConn(t.ctx, &stats.ConnTagInfo{ + RemoteAddr: t.remoteAddr, + LocalAddr: t.localAddr, + }) + connBegin := &stats.ConnBegin{} + sh.HandleConn(t.ctx, connBegin) + } + t.channelzID, err = channelz.RegisterNormalSocket(t, config.ChannelzParentID, fmt.Sprintf("%s -> %s", t.remoteAddr, t.localAddr)) + if err != nil { + return nil, err + } + + t.connectionID = atomic.AddUint64(&serverConnectionCounter, 1) + t.framer.writer.Flush() + + defer func() { + if err != nil { + t.Close(err) + } + }() + + // Check the validity of client preface. + preface := make([]byte, len(clientPreface)) + if _, err := io.ReadFull(t.conn, preface); err != nil { + // In deployments where a gRPC server runs behind a cloud load balancer + // which performs regular TCP level health checks, the connection is + // closed immediately by the latter. Returning io.EOF here allows the + // grpc server implementation to recognize this scenario and suppress + // logging to reduce spam. + if err == io.EOF { + return nil, io.EOF + } + return nil, connectionErrorf(false, err, "transport: http2Server.HandleStreams failed to receive the preface from client: %v", err) + } + if !bytes.Equal(preface, clientPreface) { + return nil, connectionErrorf(false, nil, "transport: http2Server.HandleStreams received bogus greeting from client: %q", preface) + } + + frame, err := t.framer.fr.ReadFrame() + if err == io.EOF || err == io.ErrUnexpectedEOF { + return nil, err + } + if err != nil { + return nil, connectionErrorf(false, err, "transport: http2Server.HandleStreams failed to read initial settings frame: %v", err) + } + atomic.StoreInt64(&t.lastRead, time.Now().UnixNano()) + sf, ok := frame.(*http2.SettingsFrame) + if !ok { + return nil, connectionErrorf(false, nil, "transport: http2Server.HandleStreams saw invalid preface type %T from client", frame) + } + t.handleSettings(sf) + + go func() { + t.loopy = newLoopyWriter(serverSide, t.framer, t.controlBuf, t.bdpEst, t.conn, t.logger) + t.loopy.ssGoAwayHandler = t.outgoingGoAwayHandler + t.loopy.run() + close(t.writerDone) + }() + go t.keepalive() + return t, nil +} + +// operateHeaders takes action on the decoded headers. Returns an error if fatal +// error encountered and transport needs to close, otherwise returns nil. +func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream), traceCtx func(context.Context, string) context.Context) error { + // Acquire max stream ID lock for entire duration + t.maxStreamMu.Lock() + defer t.maxStreamMu.Unlock() + + streamID := frame.Header().StreamID + + // frame.Truncated is set to true when framer detects that the current header + // list size hits MaxHeaderListSize limit. + if frame.Truncated { + t.controlBuf.put(&cleanupStream{ + streamID: streamID, + rst: true, + rstCode: http2.ErrCodeFrameSize, + onWrite: func() {}, + }) + return nil + } + + if streamID%2 != 1 || streamID <= t.maxStreamID { + // illegal gRPC stream id. + return fmt.Errorf("received an illegal stream id: %v. 
headers frame: %+v", streamID, frame) + } + t.maxStreamID = streamID + + buf := newRecvBuffer() + s := &Stream{ + id: streamID, + st: t, + buf: buf, + fc: &inFlow{limit: uint32(t.initialWindowSize)}, + } + var ( + // if false, content-type was missing or invalid + isGRPC = false + contentType = "" + mdata = make(metadata.MD, len(frame.Fields)) + httpMethod string + // these are set if an error is encountered while parsing the headers + protocolError bool + headerError *status.Status + + timeoutSet bool + timeout time.Duration + ) + + for _, hf := range frame.Fields { + switch hf.Name { + case "content-type": + contentSubtype, validContentType := grpcutil.ContentSubtype(hf.Value) + if !validContentType { + contentType = hf.Value + break + } + mdata[hf.Name] = append(mdata[hf.Name], hf.Value) + s.contentSubtype = contentSubtype + isGRPC = true + + case "grpc-accept-encoding": + mdata[hf.Name] = append(mdata[hf.Name], hf.Value) + if hf.Value == "" { + continue + } + compressors := hf.Value + if s.clientAdvertisedCompressors != "" { + compressors = s.clientAdvertisedCompressors + "," + compressors + } + s.clientAdvertisedCompressors = compressors + case "grpc-encoding": + s.recvCompress = hf.Value + case ":method": + httpMethod = hf.Value + case ":path": + s.method = hf.Value + case "grpc-timeout": + timeoutSet = true + var err error + if timeout, err = decodeTimeout(hf.Value); err != nil { + headerError = status.Newf(codes.Internal, "malformed grpc-timeout: %v", err) + } + // "Transports must consider requests containing the Connection header + // as malformed." - A41 + case "connection": + if t.logger.V(logLevel) { + t.logger.Infof("Received a HEADERS frame with a :connection header which makes the request malformed, as per the HTTP/2 spec") + } + protocolError = true + default: + if isReservedHeader(hf.Name) && !isWhitelistedHeader(hf.Name) { + break + } + v, err := decodeMetadataHeader(hf.Name, hf.Value) + if err != nil { + headerError = status.Newf(codes.Internal, "malformed binary metadata %q in header %q: %v", hf.Value, hf.Name, err) + t.logger.Warningf("Failed to decode metadata header (%q, %q): %v", hf.Name, hf.Value, err) + break + } + mdata[hf.Name] = append(mdata[hf.Name], v) + } + } + + // "If multiple Host headers or multiple :authority headers are present, the + // request must be rejected with an HTTP status code 400 as required by Host + // validation in RFC 7230 §5.4, gRPC status code INTERNAL, or RST_STREAM + // with HTTP/2 error code PROTOCOL_ERROR." - A41. Since this is a HTTP/2 + // error, this takes precedence over a client not speaking gRPC. 
+ if len(mdata[":authority"]) > 1 || len(mdata["host"]) > 1 { + errMsg := fmt.Sprintf("num values of :authority: %v, num values of host: %v, both must only have 1 value as per HTTP/2 spec", len(mdata[":authority"]), len(mdata["host"])) + if t.logger.V(logLevel) { + t.logger.Infof("Aborting the stream early: %v", errMsg) + } + t.controlBuf.put(&earlyAbortStream{ + httpStatus: http.StatusBadRequest, + streamID: streamID, + contentSubtype: s.contentSubtype, + status: status.New(codes.Internal, errMsg), + rst: !frame.StreamEnded(), + }) + return nil + } + + if protocolError { + t.controlBuf.put(&cleanupStream{ + streamID: streamID, + rst: true, + rstCode: http2.ErrCodeProtocol, + onWrite: func() {}, + }) + return nil + } + if !isGRPC { + t.controlBuf.put(&earlyAbortStream{ + httpStatus: http.StatusUnsupportedMediaType, + streamID: streamID, + contentSubtype: s.contentSubtype, + status: status.Newf(codes.InvalidArgument, "invalid gRPC request content-type %q", contentType), + rst: !frame.StreamEnded(), + }) + return nil + } + if headerError != nil { + t.controlBuf.put(&earlyAbortStream{ + httpStatus: http.StatusBadRequest, + streamID: streamID, + contentSubtype: s.contentSubtype, + status: headerError, + rst: !frame.StreamEnded(), + }) + return nil + } + + // "If :authority is missing, Host must be renamed to :authority." - A41 + if len(mdata[":authority"]) == 0 { + // No-op if host isn't present, no eventual :authority header is a valid + // RPC. + if host, ok := mdata["host"]; ok { + mdata[":authority"] = host + delete(mdata, "host") + } + } else { + // "If :authority is present, Host must be discarded" - A41 + delete(mdata, "host") + } + + if frame.StreamEnded() { + // s is just created by the caller. No lock needed. + s.state = streamReadDone + } + if timeoutSet { + s.ctx, s.cancel = context.WithTimeout(t.ctx, timeout) + } else { + s.ctx, s.cancel = context.WithCancel(t.ctx) + } + + // Attach the received metadata to the context. 
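+ // Stats tags and traces travel as grpc-tags-bin and grpc-trace-bin
+ // metadata and are extracted into the context just below.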
+ if len(mdata) > 0 { + s.ctx = metadata.NewIncomingContext(s.ctx, mdata) + if statsTags := mdata["grpc-tags-bin"]; len(statsTags) > 0 { + s.ctx = stats.SetIncomingTags(s.ctx, []byte(statsTags[len(statsTags)-1])) + } + if statsTrace := mdata["grpc-trace-bin"]; len(statsTrace) > 0 { + s.ctx = stats.SetIncomingTrace(s.ctx, []byte(statsTrace[len(statsTrace)-1])) + } + } + t.mu.Lock() + if t.state != reachable { + t.mu.Unlock() + s.cancel() + return nil + } + if uint32(len(t.activeStreams)) >= t.maxStreams { + t.mu.Unlock() + t.controlBuf.put(&cleanupStream{ + streamID: streamID, + rst: true, + rstCode: http2.ErrCodeRefusedStream, + onWrite: func() {}, + }) + s.cancel() + return nil + } + if httpMethod != http.MethodPost { + t.mu.Unlock() + errMsg := fmt.Sprintf("Received a HEADERS frame with :method %q which should be POST", httpMethod) + if t.logger.V(logLevel) { + t.logger.Infof("Aborting the stream early: %v", errMsg) + } + t.controlBuf.put(&earlyAbortStream{ + httpStatus: 405, + streamID: streamID, + contentSubtype: s.contentSubtype, + status: status.New(codes.Internal, errMsg), + rst: !frame.StreamEnded(), + }) + s.cancel() + return nil + } + if t.inTapHandle != nil { + var err error + if s.ctx, err = t.inTapHandle(s.ctx, &tap.Info{FullMethodName: s.method}); err != nil { + t.mu.Unlock() + if t.logger.V(logLevel) { + t.logger.Infof("Aborting the stream early due to InTapHandle failure: %v", err) + } + stat, ok := status.FromError(err) + if !ok { + stat = status.New(codes.PermissionDenied, err.Error()) + } + t.controlBuf.put(&earlyAbortStream{ + httpStatus: 200, + streamID: s.id, + contentSubtype: s.contentSubtype, + status: stat, + rst: !frame.StreamEnded(), + }) + return nil + } + } + t.activeStreams[streamID] = s + if len(t.activeStreams) == 1 { + t.idle = time.Time{} + } + t.mu.Unlock() + if channelz.IsOn() { + atomic.AddInt64(&t.czData.streamsStarted, 1) + atomic.StoreInt64(&t.czData.lastStreamCreatedTime, time.Now().UnixNano()) + } + s.requestRead = func(n int) { + t.adjustWindow(s, uint32(n)) + } + s.ctx = traceCtx(s.ctx, s.method) + for _, sh := range t.stats { + s.ctx = sh.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method}) + inHeader := &stats.InHeader{ + FullMethod: s.method, + RemoteAddr: t.remoteAddr, + LocalAddr: t.localAddr, + Compression: s.recvCompress, + WireLength: int(frame.Header().Length), + Header: mdata.Copy(), + } + sh.HandleRPC(s.ctx, inHeader) + } + s.ctxDone = s.ctx.Done() + s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone) + s.trReader = &transportReader{ + reader: &recvBufferReader{ + ctx: s.ctx, + ctxDone: s.ctxDone, + recv: s.buf, + freeBuffer: t.bufferPool.put, + }, + windowHandler: func(n int) { + t.updateWindow(s, uint32(n)) + }, + } + // Register the stream with loopy. + t.controlBuf.put(®isterStream{ + streamID: s.id, + wq: s.wq, + }) + handle(s) + return nil +} + +// HandleStreams receives incoming streams using the given handler. This is +// typically run in a separate goroutine. +// traceCtx attaches trace to ctx and returns the new context. 
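+//
+// The loop below is the sole reader of the connection: a stream-level read
+// error resets only the affected stream, while any other read error closes
+// the whole transport.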
+func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context.Context, string) context.Context) { + defer close(t.readerDone) + for { + t.controlBuf.throttle() + frame, err := t.framer.fr.ReadFrame() + atomic.StoreInt64(&t.lastRead, time.Now().UnixNano()) + if err != nil { + if se, ok := err.(http2.StreamError); ok { + if t.logger.V(logLevel) { + t.logger.Warningf("Encountered http2.StreamError: %v", se) + } + t.mu.Lock() + s := t.activeStreams[se.StreamID] + t.mu.Unlock() + if s != nil { + t.closeStream(s, true, se.Code, false) + } else { + t.controlBuf.put(&cleanupStream{ + streamID: se.StreamID, + rst: true, + rstCode: se.Code, + onWrite: func() {}, + }) + } + continue + } + if err == io.EOF || err == io.ErrUnexpectedEOF { + t.Close(err) + return + } + t.Close(err) + return + } + switch frame := frame.(type) { + case *http2.MetaHeadersFrame: + if err := t.operateHeaders(frame, handle, traceCtx); err != nil { + t.Close(err) + break + } + case *http2.DataFrame: + t.handleData(frame) + case *http2.RSTStreamFrame: + t.handleRSTStream(frame) + case *http2.SettingsFrame: + t.handleSettings(frame) + case *http2.PingFrame: + t.handlePing(frame) + case *http2.WindowUpdateFrame: + t.handleWindowUpdate(frame) + case *http2.GoAwayFrame: + // TODO: Handle GoAway from the client appropriately. + default: + if t.logger.V(logLevel) { + t.logger.Infof("Received unsupported frame type %T", frame) + } + } + } +} + +func (t *http2Server) getStream(f http2.Frame) (*Stream, bool) { + t.mu.Lock() + defer t.mu.Unlock() + if t.activeStreams == nil { + // The transport is closing. + return nil, false + } + s, ok := t.activeStreams[f.Header().StreamID] + if !ok { + // The stream is already done. + return nil, false + } + return s, true +} + +// adjustWindow sends out extra window update over the initial window size +// of stream if the application is requesting data larger in size than +// the window. +func (t *http2Server) adjustWindow(s *Stream, n uint32) { + if w := s.fc.maybeAdjust(n); w > 0 { + t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id, increment: w}) + } + +} + +// updateWindow adjusts the inbound quota for the stream and the transport. +// Window updates will deliver to the controller for sending when +// the cumulative quota exceeds the corresponding threshold. +func (t *http2Server) updateWindow(s *Stream, n uint32) { + if w := s.fc.onRead(n); w > 0 { + t.controlBuf.put(&outgoingWindowUpdate{streamID: s.id, + increment: w, + }) + } +} + +// updateFlowControl updates the incoming flow control windows +// for the transport and the stream based on the current bdp +// estimation. +func (t *http2Server) updateFlowControl(n uint32) { + t.mu.Lock() + for _, s := range t.activeStreams { + s.fc.newLimit(n) + } + t.initialWindowSize = int32(n) + t.mu.Unlock() + t.controlBuf.put(&outgoingWindowUpdate{ + streamID: 0, + increment: t.fc.newLimit(n), + }) + t.controlBuf.put(&outgoingSettings{ + ss: []http2.Setting{ + { + ID: http2.SettingInitialWindowSize, + Val: n, + }, + }, + }) + +} + +func (t *http2Server) handleData(f *http2.DataFrame) { + size := f.Header().Length + var sendBDPPing bool + if t.bdpEst != nil { + sendBDPPing = t.bdpEst.add(size) + } + // Decouple connection's flow control from application's read. + // An update on connection's flow control should not depend on + // whether user application has read the data or not. Such a + // restriction is already imposed on the stream's flow control, + // and therefore the sender will be blocked anyways. 
+ // Decoupling the connection flow control will prevent other + // active(fast) streams from starving in presence of slow or + // inactive streams. + if w := t.fc.onData(size); w > 0 { + t.controlBuf.put(&outgoingWindowUpdate{ + streamID: 0, + increment: w, + }) + } + if sendBDPPing { + // Avoid excessive ping detection (e.g. in an L7 proxy) + // by sending a window update prior to the BDP ping. + if w := t.fc.reset(); w > 0 { + t.controlBuf.put(&outgoingWindowUpdate{ + streamID: 0, + increment: w, + }) + } + t.controlBuf.put(bdpPing) + } + // Select the right stream to dispatch. + s, ok := t.getStream(f) + if !ok { + return + } + if s.getState() == streamReadDone { + t.closeStream(s, true, http2.ErrCodeStreamClosed, false) + return + } + if size > 0 { + if err := s.fc.onData(size); err != nil { + t.closeStream(s, true, http2.ErrCodeFlowControl, false) + return + } + if f.Header().Flags.Has(http2.FlagDataPadded) { + if w := s.fc.onRead(size - uint32(len(f.Data()))); w > 0 { + t.controlBuf.put(&outgoingWindowUpdate{s.id, w}) + } + } + // TODO(bradfitz, zhaoq): A copy is required here because there is no + // guarantee f.Data() is consumed before the arrival of next frame. + // Can this copy be eliminated? + if len(f.Data()) > 0 { + buffer := t.bufferPool.get() + buffer.Reset() + buffer.Write(f.Data()) + s.write(recvMsg{buffer: buffer}) + } + } + if f.StreamEnded() { + // Received the end of stream from the client. + s.compareAndSwapState(streamActive, streamReadDone) + s.write(recvMsg{err: io.EOF}) + } +} + +func (t *http2Server) handleRSTStream(f *http2.RSTStreamFrame) { + // If the stream is not deleted from the transport's active streams map, then do a regular close stream. + if s, ok := t.getStream(f); ok { + t.closeStream(s, false, 0, false) + return + } + // If the stream is already deleted from the active streams map, then put a cleanupStream item into controlbuf to delete the stream from loopy writer's established streams map. + t.controlBuf.put(&cleanupStream{ + streamID: f.Header().StreamID, + rst: false, + rstCode: 0, + onWrite: func() {}, + }) +} + +func (t *http2Server) handleSettings(f *http2.SettingsFrame) { + if f.IsAck() { + return + } + var ss []http2.Setting + var updateFuncs []func() + f.ForeachSetting(func(s http2.Setting) error { + switch s.ID { + case http2.SettingMaxHeaderListSize: + updateFuncs = append(updateFuncs, func() { + t.maxSendHeaderListSize = new(uint32) + *t.maxSendHeaderListSize = s.Val + }) + default: + ss = append(ss, s) + } + return nil + }) + t.controlBuf.executeAndPut(func(any) bool { + for _, f := range updateFuncs { + f() + } + return true + }, &incomingSettings{ + ss: ss, + }) +} + +const ( + maxPingStrikes = 2 + defaultPingTimeout = 2 * time.Hour +) + +func (t *http2Server) handlePing(f *http2.PingFrame) { + if f.IsAck() { + if f.Data == goAwayPing.data && t.drainEvent != nil { + t.drainEvent.Fire() + return + } + // Maybe it's a BDP ping. + if t.bdpEst != nil { + t.bdpEst.calculate(f.Data) + } + return + } + pingAck := &ping{ack: true} + copy(pingAck.data[:], f.Data[:]) + t.controlBuf.put(pingAck) + + now := time.Now() + defer func() { + t.lastPingAt = now + }() + // A reset ping strikes means that we don't need to check for policy + // violation for this ping and the pingStrikes counter should be set + // to 0. 
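+ // Otherwise each ping that arrives too early earns a strike, and more
+ // than maxPingStrikes of them triggers a GOAWAY with ENHANCE_YOUR_CALM
+ // below.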
+ if atomic.CompareAndSwapUint32(&t.resetPingStrikes, 1, 0) { + t.pingStrikes = 0 + return + } + t.mu.Lock() + ns := len(t.activeStreams) + t.mu.Unlock() + if ns < 1 && !t.kep.PermitWithoutStream { + // Keepalive shouldn't be active thus, this new ping should + // have come after at least defaultPingTimeout. + if t.lastPingAt.Add(defaultPingTimeout).After(now) { + t.pingStrikes++ + } + } else { + // Check if keepalive policy is respected. + if t.lastPingAt.Add(t.kep.MinTime).After(now) { + t.pingStrikes++ + } + } + + if t.pingStrikes > maxPingStrikes { + // Send goaway and close the connection. + t.controlBuf.put(&goAway{code: http2.ErrCodeEnhanceYourCalm, debugData: []byte("too_many_pings"), closeConn: errors.New("got too many pings from the client")}) + } +} + +func (t *http2Server) handleWindowUpdate(f *http2.WindowUpdateFrame) { + t.controlBuf.put(&incomingWindowUpdate{ + streamID: f.Header().StreamID, + increment: f.Increment, + }) +} + +func appendHeaderFieldsFromMD(headerFields []hpack.HeaderField, md metadata.MD) []hpack.HeaderField { + for k, vv := range md { + if isReservedHeader(k) { + // Clients don't tolerate reading restricted headers after some non restricted ones were sent. + continue + } + for _, v := range vv { + headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) + } + } + return headerFields +} + +func (t *http2Server) checkForHeaderListSize(it any) bool { + if t.maxSendHeaderListSize == nil { + return true + } + hdrFrame := it.(*headerFrame) + var sz int64 + for _, f := range hdrFrame.hf { + if sz += int64(f.Size()); sz > int64(*t.maxSendHeaderListSize) { + if t.logger.V(logLevel) { + t.logger.Infof("Header list size to send violates the maximum size (%d bytes) set by client", *t.maxSendHeaderListSize) + } + return false + } + } + return true +} + +func (t *http2Server) streamContextErr(s *Stream) error { + select { + case <-t.done: + return ErrConnClosing + default: + } + return ContextErr(s.ctx.Err()) +} + +// WriteHeader sends the header metadata md back to the client. +func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { + s.hdrMu.Lock() + defer s.hdrMu.Unlock() + if s.getState() == streamDone { + return t.streamContextErr(s) + } + + if s.updateHeaderSent() { + return ErrIllegalHeaderWrite + } + + if md.Len() > 0 { + if s.header.Len() > 0 { + s.header = metadata.Join(s.header, md) + } else { + s.header = md + } + } + if err := t.writeHeaderLocked(s); err != nil { + return status.Convert(err).Err() + } + return nil +} + +func (t *http2Server) setResetPingStrikes() { + atomic.StoreUint32(&t.resetPingStrikes, 1) +} + +func (t *http2Server) writeHeaderLocked(s *Stream) error { + // TODO(mmukhi): Benchmark if the performance gets better if count the metadata and other header fields + // first and create a slice of that exact size. + headerFields := make([]hpack.HeaderField, 0, 2) // at least :status, content-type will be there if none else. 
+ headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"}) + headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: grpcutil.ContentType(s.contentSubtype)}) + if s.sendCompress != "" { + headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress}) + } + headerFields = appendHeaderFieldsFromMD(headerFields, s.header) + success, err := t.controlBuf.executeAndPut(t.checkForHeaderListSize, &headerFrame{ + streamID: s.id, + hf: headerFields, + endStream: false, + onWrite: t.setResetPingStrikes, + }) + if !success { + if err != nil { + return err + } + t.closeStream(s, true, http2.ErrCodeInternal, false) + return ErrHeaderListSizeLimitViolation + } + for _, sh := range t.stats { + // Note: Headers are compressed with hpack after this call returns. + // No WireLength field is set here. + outHeader := &stats.OutHeader{ + Header: s.header.Copy(), + Compression: s.sendCompress, + } + sh.HandleRPC(s.Context(), outHeader) + } + return nil +} + +// WriteStatus sends stream status to the client and terminates the stream. +// There is no further I/O operations being able to perform on this stream. +// TODO(zhaoq): Now it indicates the end of entire stream. Revisit if early +// OK is adopted. +func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { + s.hdrMu.Lock() + defer s.hdrMu.Unlock() + + if s.getState() == streamDone { + return nil + } + + // TODO(mmukhi): Benchmark if the performance gets better if count the metadata and other header fields + // first and create a slice of that exact size. + headerFields := make([]hpack.HeaderField, 0, 2) // grpc-status and grpc-message will be there if none else. + if !s.updateHeaderSent() { // No headers have been sent. + if len(s.header) > 0 { // Send a separate header frame. + if err := t.writeHeaderLocked(s); err != nil { + return err + } + } else { // Send a trailer only response. + headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"}) + headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: grpcutil.ContentType(s.contentSubtype)}) + } + } + headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status", Value: strconv.Itoa(int(st.Code()))}) + headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(st.Message())}) + + if p := st.Proto(); p != nil && len(p.Details) > 0 { + stBytes, err := proto.Marshal(p) + if err != nil { + // TODO: return error instead, when callers are able to handle it. + t.logger.Errorf("Failed to marshal rpc status: %s, error: %v", pretty.ToJSON(p), err) + } else { + headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status-details-bin", Value: encodeBinHeader(stBytes)}) + } + } + + // Attach the trailer metadata. + headerFields = appendHeaderFieldsFromMD(headerFields, s.trailer) + trailingHeader := &headerFrame{ + streamID: s.id, + hf: headerFields, + endStream: true, + onWrite: t.setResetPingStrikes, + } + + success, err := t.controlBuf.execute(t.checkForHeaderListSize, trailingHeader) + if !success { + if err != nil { + return err + } + t.closeStream(s, true, http2.ErrCodeInternal, false) + return ErrHeaderListSizeLimitViolation + } + // Send a RST_STREAM after the trailers if the client has not already half-closed. 
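+ // If the client has already half-closed, the END_STREAM flag carried by
+ // the trailers is sufficient to terminate the stream.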
+ rst := s.getState() == streamActive + t.finishStream(s, rst, http2.ErrCodeNo, trailingHeader, true) + for _, sh := range t.stats { + // Note: The trailer fields are compressed with hpack after this call returns. + // No WireLength field is set here. + sh.HandleRPC(s.Context(), &stats.OutTrailer{ + Trailer: s.trailer.Copy(), + }) + } + return nil +} + +// Write converts the data into HTTP2 data frame and sends it out. Non-nil error +// is returns if it fails (e.g., framing error, transport error). +func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) error { + if !s.isHeaderSent() { // Headers haven't been written yet. + if err := t.WriteHeader(s, nil); err != nil { + return err + } + } else { + // Writing headers checks for this condition. + if s.getState() == streamDone { + return t.streamContextErr(s) + } + } + df := &dataFrame{ + streamID: s.id, + h: hdr, + d: data, + onEachWrite: t.setResetPingStrikes, + } + if err := s.wq.get(int32(len(hdr) + len(data))); err != nil { + return t.streamContextErr(s) + } + return t.controlBuf.put(df) +} + +// keepalive running in a separate goroutine does the following: +// 1. Gracefully closes an idle connection after a duration of keepalive.MaxConnectionIdle. +// 2. Gracefully closes any connection after a duration of keepalive.MaxConnectionAge. +// 3. Forcibly closes a connection after an additive period of keepalive.MaxConnectionAgeGrace over keepalive.MaxConnectionAge. +// 4. Makes sure a connection is alive by sending pings with a frequency of keepalive.Time and closes a non-responsive connection +// after an additional duration of keepalive.Timeout. +func (t *http2Server) keepalive() { + p := &ping{} + // True iff a ping has been sent, and no data has been received since then. + outstandingPing := false + // Amount of time remaining before which we should receive an ACK for the + // last sent ping. + kpTimeoutLeft := time.Duration(0) + // Records the last value of t.lastRead before we go block on the timer. + // This is required to check for read activity since then. + prevNano := time.Now().UnixNano() + // Initialize the different timers to their default values. + idleTimer := time.NewTimer(t.kp.MaxConnectionIdle) + ageTimer := time.NewTimer(t.kp.MaxConnectionAge) + kpTimer := time.NewTimer(t.kp.Time) + defer func() { + // We need to drain the underlying channel in these timers after a call + // to Stop(), only if we are interested in resetting them. Clearly we + // are not interested in resetting them here. + idleTimer.Stop() + ageTimer.Stop() + kpTimer.Stop() + }() + + for { + select { + case <-idleTimer.C: + t.mu.Lock() + idle := t.idle + if idle.IsZero() { // The connection is non-idle. + t.mu.Unlock() + idleTimer.Reset(t.kp.MaxConnectionIdle) + continue + } + val := t.kp.MaxConnectionIdle - time.Since(idle) + t.mu.Unlock() + if val <= 0 { + // The connection has been idle for a duration of keepalive.MaxConnectionIdle or more. + // Gracefully close the connection. + t.Drain("max_idle") + return + } + idleTimer.Reset(val) + case <-ageTimer.C: + t.Drain("max_age") + ageTimer.Reset(t.kp.MaxConnectionAgeGrace) + select { + case <-ageTimer.C: + // Close the connection after grace period. 
+ if t.logger.V(logLevel) { + t.logger.Infof("Closing server transport due to maximum connection age") + } + t.controlBuf.put(closeConnection{}) + case <-t.done: + } + return + case <-kpTimer.C: + lastRead := atomic.LoadInt64(&t.lastRead) + if lastRead > prevNano { + // There has been read activity since the last time we were + // here. Setup the timer to fire at kp.Time seconds from + // lastRead time and continue. + outstandingPing = false + kpTimer.Reset(time.Duration(lastRead) + t.kp.Time - time.Duration(time.Now().UnixNano())) + prevNano = lastRead + continue + } + if outstandingPing && kpTimeoutLeft <= 0 { + t.Close(fmt.Errorf("keepalive ping not acked within timeout %s", t.kp.Time)) + return + } + if !outstandingPing { + if channelz.IsOn() { + atomic.AddInt64(&t.czData.kpCount, 1) + } + t.controlBuf.put(p) + kpTimeoutLeft = t.kp.Timeout + outstandingPing = true + } + // The amount of time to sleep here is the minimum of kp.Time and + // timeoutLeft. This will ensure that we wait only for kp.Time + // before sending out the next ping (for cases where the ping is + // acked). + sleepDuration := minTime(t.kp.Time, kpTimeoutLeft) + kpTimeoutLeft -= sleepDuration + kpTimer.Reset(sleepDuration) + case <-t.done: + return + } + } +} + +// Close starts shutting down the http2Server transport. +// TODO(zhaoq): Now the destruction is not blocked on any pending streams. This +// could cause some resource issue. Revisit this later. +func (t *http2Server) Close(err error) { + t.mu.Lock() + if t.state == closing { + t.mu.Unlock() + return + } + if t.logger.V(logLevel) { + t.logger.Infof("Closing: %v", err) + } + t.state = closing + streams := t.activeStreams + t.activeStreams = nil + t.mu.Unlock() + t.controlBuf.finish() + close(t.done) + if err := t.conn.Close(); err != nil && t.logger.V(logLevel) { + t.logger.Infof("Error closing underlying net.Conn during Close: %v", err) + } + channelz.RemoveEntry(t.channelzID) + // Cancel all active streams. + for _, s := range streams { + s.cancel() + } + for _, sh := range t.stats { + connEnd := &stats.ConnEnd{} + sh.HandleConn(t.ctx, connEnd) + } +} + +// deleteStream deletes the stream s from transport's active streams. +func (t *http2Server) deleteStream(s *Stream, eosReceived bool) { + + t.mu.Lock() + if _, ok := t.activeStreams[s.id]; ok { + delete(t.activeStreams, s.id) + if len(t.activeStreams) == 0 { + t.idle = time.Now() + } + } + t.mu.Unlock() + + if channelz.IsOn() { + if eosReceived { + atomic.AddInt64(&t.czData.streamsSucceeded, 1) + } else { + atomic.AddInt64(&t.czData.streamsFailed, 1) + } + } +} + +// finishStream closes the stream and puts the trailing headerFrame into controlbuf. +func (t *http2Server) finishStream(s *Stream, rst bool, rstCode http2.ErrCode, hdr *headerFrame, eosReceived bool) { + // In case stream sending and receiving are invoked in separate + // goroutines (e.g., bi-directional streaming), cancel needs to be + // called to interrupt the potential blocking on other goroutines. + s.cancel() + + oldState := s.swapState(streamDone) + if oldState == streamDone { + // If the stream was already done, return. + return + } + + hdr.cleanup = &cleanupStream{ + streamID: s.id, + rst: rst, + rstCode: rstCode, + onWrite: func() { + t.deleteStream(s, eosReceived) + }, + } + t.controlBuf.put(hdr) +} + +// closeStream clears the footprint of a stream when the stream is not needed any more. 
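+// Unlike finishStream it emits no trailers: it cancels the stream's context,
+// marks the stream done, removes it from the active set, and queues a cleanup
+// item so that loopy forgets the stream as well.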
+func (t *http2Server) closeStream(s *Stream, rst bool, rstCode http2.ErrCode, eosReceived bool) { + // In case stream sending and receiving are invoked in separate + // goroutines (e.g., bi-directional streaming), cancel needs to be + // called to interrupt the potential blocking on other goroutines. + s.cancel() + + s.swapState(streamDone) + t.deleteStream(s, eosReceived) + + t.controlBuf.put(&cleanupStream{ + streamID: s.id, + rst: rst, + rstCode: rstCode, + onWrite: func() {}, + }) +} + +func (t *http2Server) RemoteAddr() net.Addr { + return t.remoteAddr +} + +func (t *http2Server) Drain(debugData string) { + t.mu.Lock() + defer t.mu.Unlock() + if t.drainEvent != nil { + return + } + t.drainEvent = grpcsync.NewEvent() + t.controlBuf.put(&goAway{code: http2.ErrCodeNo, debugData: []byte(debugData), headsUp: true}) +} + +var goAwayPing = &ping{data: [8]byte{1, 6, 1, 8, 0, 3, 3, 9}} + +// Handles outgoing GoAway and returns true if loopy needs to put itself +// in draining mode. +func (t *http2Server) outgoingGoAwayHandler(g *goAway) (bool, error) { + t.maxStreamMu.Lock() + t.mu.Lock() + if t.state == closing { // TODO(mmukhi): This seems unnecessary. + t.mu.Unlock() + t.maxStreamMu.Unlock() + // The transport is closing. + return false, ErrConnClosing + } + if !g.headsUp { + // Stop accepting more streams now. + t.state = draining + sid := t.maxStreamID + retErr := g.closeConn + if len(t.activeStreams) == 0 { + retErr = errors.New("second GOAWAY written and no active streams left to process") + } + t.mu.Unlock() + t.maxStreamMu.Unlock() + if err := t.framer.fr.WriteGoAway(sid, g.code, g.debugData); err != nil { + return false, err + } + if retErr != nil { + return false, retErr + } + return true, nil + } + t.mu.Unlock() + t.maxStreamMu.Unlock() + // For a graceful close, send out a GoAway with stream ID of MaxUInt32, + // Follow that with a ping and wait for the ack to come back or a timer + // to expire. During this time accept new streams since they might have + // originated before the GoAway reaches the client. + // After getting the ack or timer expiration send out another GoAway this + // time with an ID of the max stream server intends to process. 
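+ // This two-step GOAWAY sequence is the graceful-shutdown scheme
+ // described in RFC 7540 Section 6.8.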
+ if err := t.framer.fr.WriteGoAway(math.MaxUint32, http2.ErrCodeNo, g.debugData); err != nil { + return false, err + } + if err := t.framer.fr.WritePing(false, goAwayPing.data); err != nil { + return false, err + } + go func() { + timer := time.NewTimer(time.Minute) + defer timer.Stop() + select { + case <-t.drainEvent.Done(): + case <-timer.C: + case <-t.done: + return + } + t.controlBuf.put(&goAway{code: g.code, debugData: g.debugData}) + }() + return false, nil +} + +func (t *http2Server) ChannelzMetric() *channelz.SocketInternalMetric { + s := channelz.SocketInternalMetric{ + StreamsStarted: atomic.LoadInt64(&t.czData.streamsStarted), + StreamsSucceeded: atomic.LoadInt64(&t.czData.streamsSucceeded), + StreamsFailed: atomic.LoadInt64(&t.czData.streamsFailed), + MessagesSent: atomic.LoadInt64(&t.czData.msgSent), + MessagesReceived: atomic.LoadInt64(&t.czData.msgRecv), + KeepAlivesSent: atomic.LoadInt64(&t.czData.kpCount), + LastRemoteStreamCreatedTimestamp: time.Unix(0, atomic.LoadInt64(&t.czData.lastStreamCreatedTime)), + LastMessageSentTimestamp: time.Unix(0, atomic.LoadInt64(&t.czData.lastMsgSentTime)), + LastMessageReceivedTimestamp: time.Unix(0, atomic.LoadInt64(&t.czData.lastMsgRecvTime)), + LocalFlowControlWindow: int64(t.fc.getSize()), + SocketOptions: channelz.GetSocketOption(t.conn), + LocalAddr: t.localAddr, + RemoteAddr: t.remoteAddr, + // RemoteName : + } + if au, ok := t.authInfo.(credentials.ChannelzSecurityInfo); ok { + s.Security = au.GetSecurityValue() + } + s.RemoteFlowControlWindow = t.getOutFlowWindow() + return &s +} + +func (t *http2Server) IncrMsgSent() { + atomic.AddInt64(&t.czData.msgSent, 1) + atomic.StoreInt64(&t.czData.lastMsgSentTime, time.Now().UnixNano()) +} + +func (t *http2Server) IncrMsgRecv() { + atomic.AddInt64(&t.czData.msgRecv, 1) + atomic.StoreInt64(&t.czData.lastMsgRecvTime, time.Now().UnixNano()) +} + +func (t *http2Server) getOutFlowWindow() int64 { + resp := make(chan uint32, 1) + timer := time.NewTimer(time.Second) + defer timer.Stop() + t.controlBuf.put(&outFlowControlSizeRequest{resp}) + select { + case sz := <-resp: + return int64(sz) + case <-t.done: + return -1 + case <-timer.C: + return -2 + } +} + +func (t *http2Server) getPeer() *peer.Peer { + return &peer.Peer{ + Addr: t.remoteAddr, + AuthInfo: t.authInfo, // Can be nil + } +} + +func getJitter(v time.Duration) time.Duration { + if v == infinity { + return 0 + } + // Generate a jitter between +/- 10% of the value. + r := int64(v / 10) + j := grpcrand.Int63n(2*r) - r + return time.Duration(j) +} + +type connectionKey struct{} + +// GetConnection gets the connection from the context. +func GetConnection(ctx context.Context) net.Conn { + conn, _ := ctx.Value(connectionKey{}).(net.Conn) + return conn +} + +// SetConnection adds the connection to the context to be able to get +// information about the destination ip and port for an incoming RPC. This also +// allows any unary or streaming interceptors to see the connection. +func setConnection(ctx context.Context, conn net.Conn) context.Context { + return context.WithValue(ctx, connectionKey{}, conn) +} diff --git a/vendor/google.golang.org/grpc/server.go b/vendor/google.golang.org/grpc/server.go new file mode 100644 index 0000000000..244123c6c5 --- /dev/null +++ b/vendor/google.golang.org/grpc/server.go @@ -0,0 +1,2093 @@ +/* + * + * Copyright 2014 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package grpc + +import ( + "context" + "errors" + "fmt" + "io" + "math" + "net" + "net/http" + "reflect" + "runtime" + "strings" + "sync" + "sync/atomic" + "time" + + "golang.org/x/net/trace" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/encoding" + "google.golang.org/grpc/encoding/proto" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/internal" + "google.golang.org/grpc/internal/binarylog" + "google.golang.org/grpc/internal/channelz" + "google.golang.org/grpc/internal/grpcsync" + "google.golang.org/grpc/internal/grpcutil" + "google.golang.org/grpc/internal/transport" + "google.golang.org/grpc/keepalive" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/peer" + "google.golang.org/grpc/stats" + "google.golang.org/grpc/status" + "google.golang.org/grpc/tap" +) + +const ( + defaultServerMaxReceiveMessageSize = 1024 * 1024 * 4 + defaultServerMaxSendMessageSize = math.MaxInt32 + + // Server transports are tracked in a map which is keyed on listener + // address. For regular gRPC traffic, connections are accepted in Serve() + // through a call to Accept(), and we use the actual listener address as key + // when we add it to the map. But for connections received through + // ServeHTTP(), we do not have a listener and hence use this dummy value. + listenerAddressForServeHTTP = "listenerAddressForServeHTTP" +) + +func init() { + internal.GetServerCredentials = func(srv *Server) credentials.TransportCredentials { + return srv.opts.creds + } + internal.DrainServerTransports = func(srv *Server, addr string) { + srv.drainServerTransports(addr) + } + internal.AddGlobalServerOptions = func(opt ...ServerOption) { + globalServerOptions = append(globalServerOptions, opt...) + } + internal.ClearGlobalServerOptions = func() { + globalServerOptions = nil + } + internal.BinaryLogger = binaryLogger + internal.JoinServerOptions = newJoinServerOption +} + +var statusOK = status.New(codes.OK, "") +var logger = grpclog.Component("core") + +type methodHandler func(srv any, ctx context.Context, dec func(any) error, interceptor UnaryServerInterceptor) (any, error) + +// MethodDesc represents an RPC service's method specification. +type MethodDesc struct { + MethodName string + Handler methodHandler +} + +// ServiceDesc represents an RPC service's specification. +type ServiceDesc struct { + ServiceName string + // The pointer to the service interface. Used to check whether the user + // provided implementation satisfies the interface requirements. + HandlerType any + Methods []MethodDesc + Streams []StreamDesc + Metadata any +} + +// serviceInfo wraps information about a service. It is very similar to +// ServiceDesc and is constructed from it for internal purposes. +type serviceInfo struct { + // Contains the implementation for the methods in this service. 
+ serviceImpl any + methods map[string]*MethodDesc + streams map[string]*StreamDesc + mdata any +} + +type serverWorkerData struct { + st transport.ServerTransport + wg *sync.WaitGroup + stream *transport.Stream +} + +// Server is a gRPC server to serve RPC requests. +type Server struct { + opts serverOptions + + mu sync.Mutex // guards following + lis map[net.Listener]bool + // conns contains all active server transports. It is a map keyed on a + // listener address with the value being the set of active transports + // belonging to that listener. + conns map[string]map[transport.ServerTransport]bool + serve bool + drain bool + cv *sync.Cond // signaled when connections close for GracefulStop + services map[string]*serviceInfo // service name -> service info + events trace.EventLog + + quit *grpcsync.Event + done *grpcsync.Event + channelzRemoveOnce sync.Once + serveWG sync.WaitGroup // counts active Serve goroutines for GracefulStop + + channelzID *channelz.Identifier + czData *channelzData + + serverWorkerChannel chan *serverWorkerData +} + +type serverOptions struct { + creds credentials.TransportCredentials + codec baseCodec + cp Compressor + dc Decompressor + unaryInt UnaryServerInterceptor + streamInt StreamServerInterceptor + chainUnaryInts []UnaryServerInterceptor + chainStreamInts []StreamServerInterceptor + binaryLogger binarylog.Logger + inTapHandle tap.ServerInHandle + statsHandlers []stats.Handler + maxConcurrentStreams uint32 + maxReceiveMessageSize int + maxSendMessageSize int + unknownStreamDesc *StreamDesc + keepaliveParams keepalive.ServerParameters + keepalivePolicy keepalive.EnforcementPolicy + initialWindowSize int32 + initialConnWindowSize int32 + writeBufferSize int + readBufferSize int + sharedWriteBuffer bool + connectionTimeout time.Duration + maxHeaderListSize *uint32 + headerTableSize *uint32 + numServerWorkers uint32 + recvBufferPool SharedBufferPool +} + +var defaultServerOptions = serverOptions{ + maxReceiveMessageSize: defaultServerMaxReceiveMessageSize, + maxSendMessageSize: defaultServerMaxSendMessageSize, + connectionTimeout: 120 * time.Second, + writeBufferSize: defaultWriteBufSize, + readBufferSize: defaultReadBufSize, + recvBufferPool: nopBufferPool{}, +} +var globalServerOptions []ServerOption + +// A ServerOption sets options such as credentials, codec and keepalive parameters, etc. +type ServerOption interface { + apply(*serverOptions) +} + +// EmptyServerOption does not alter the server configuration. It can be embedded +// in another structure to build custom server options. +// +// # Experimental +// +// Notice: This type is EXPERIMENTAL and may be changed or removed in a +// later release. +type EmptyServerOption struct{} + +func (EmptyServerOption) apply(*serverOptions) {} + +// funcServerOption wraps a function that modifies serverOptions into an +// implementation of the ServerOption interface. +type funcServerOption struct { + f func(*serverOptions) +} + +func (fdo *funcServerOption) apply(do *serverOptions) { + fdo.f(do) +} + +func newFuncServerOption(f func(*serverOptions)) *funcServerOption { + return &funcServerOption{ + f: f, + } +} + +// joinServerOption provides a way to combine arbitrary number of server +// options into one. 
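+//
+// A small sketch of the intended use, with names defined in this file (the
+// resulting option is applied like any other when constructing a server):
+//
+//	opt := newJoinServerOption(MaxRecvMsgSize(4*1024*1024), WriteBufferSize(64*1024))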
+type joinServerOption struct {
+	opts []ServerOption
+}
+
+func (mdo *joinServerOption) apply(do *serverOptions) {
+	for _, opt := range mdo.opts {
+		opt.apply(do)
+	}
+}
+
+func newJoinServerOption(opts ...ServerOption) ServerOption {
+	return &joinServerOption{opts: opts}
+}
+
+// SharedWriteBuffer allows reusing the per-connection transport write buffer.
+// If this option is set to true, every connection will release the buffer after
+// flushing the data on the wire.
+//
+// # Experimental
+//
+// Notice: This API is EXPERIMENTAL and may be changed or removed in a
+// later release.
+func SharedWriteBuffer(val bool) ServerOption {
+	return newFuncServerOption(func(o *serverOptions) {
+		o.sharedWriteBuffer = val
+	})
+}
+
+// WriteBufferSize determines how much data can be batched before doing a write
+// on the wire. The corresponding memory allocation for this buffer will be
+// twice the size to keep syscalls low. The default value for this buffer is
+// 32KB. Zero or negative values will disable the write buffer such that each
+// write goes directly to the underlying connection.
+// Note: A Send call may not directly translate to a write.
+func WriteBufferSize(s int) ServerOption {
+	return newFuncServerOption(func(o *serverOptions) {
+		o.writeBufferSize = s
+	})
+}
+
+// ReadBufferSize lets you set the size of the read buffer; this determines how
+// much data can be read at most for one read syscall. The default value for
+// this buffer is 32KB. Zero or negative values will disable the read buffer
+// for a connection so the data framer can access the underlying conn directly.
+func ReadBufferSize(s int) ServerOption {
+	return newFuncServerOption(func(o *serverOptions) {
+		o.readBufferSize = s
+	})
+}
+
+// InitialWindowSize returns a ServerOption that sets the window size for a stream.
+// The lower bound for window size is 64K and any value smaller than that will be ignored.
+func InitialWindowSize(s int32) ServerOption {
+	return newFuncServerOption(func(o *serverOptions) {
+		o.initialWindowSize = s
+	})
+}
+
+// InitialConnWindowSize returns a ServerOption that sets the window size for a connection.
+// The lower bound for window size is 64K and any value smaller than that will be ignored.
+func InitialConnWindowSize(s int32) ServerOption {
+	return newFuncServerOption(func(o *serverOptions) {
+		o.initialConnWindowSize = s
+	})
+}
+
+// KeepaliveParams returns a ServerOption that sets keepalive and max-age parameters for the server.
+func KeepaliveParams(kp keepalive.ServerParameters) ServerOption {
+	if kp.Time > 0 && kp.Time < internal.KeepaliveMinServerPingTime {
+		logger.Warning("Adjusting keepalive ping interval to minimum period of 1s")
+		kp.Time = internal.KeepaliveMinServerPingTime
+	}
+
+	return newFuncServerOption(func(o *serverOptions) {
+		o.keepaliveParams = kp
+	})
+}
+
+// KeepaliveEnforcementPolicy returns a ServerOption that sets the keepalive enforcement policy for the server.
+func KeepaliveEnforcementPolicy(kep keepalive.EnforcementPolicy) ServerOption {
+	return newFuncServerOption(func(o *serverOptions) {
+		o.keepalivePolicy = kep
+	})
+}
+
+// CustomCodec returns a ServerOption that sets a codec for message marshaling and unmarshaling.
+//
+// This will override any lookups by content-subtype for Codecs registered with RegisterCodec.
+//
+// Deprecated: register codecs using encoding.RegisterCodec. The server will
+// automatically use registered codecs based on the incoming requests' headers.
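+// For example (sketch; mycodec is a placeholder for a package whose init
+// function calls encoding.RegisterCodec):
+//
+//	import _ "path/to/mycodec"
+//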
+// See also +// https://github.com/grpc/grpc-go/blob/master/Documentation/encoding.md#using-a-codec. +// Will be supported throughout 1.x. +func CustomCodec(codec Codec) ServerOption { + return newFuncServerOption(func(o *serverOptions) { + o.codec = codec + }) +} + +// ForceServerCodec returns a ServerOption that sets a codec for message +// marshaling and unmarshaling. +// +// This will override any lookups by content-subtype for Codecs registered +// with RegisterCodec. +// +// See Content-Type on +// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for +// more details. Also see the documentation on RegisterCodec and +// CallContentSubtype for more details on the interaction between encoding.Codec +// and content-subtype. +// +// This function is provided for advanced users; prefer to register codecs +// using encoding.RegisterCodec. +// The server will automatically use registered codecs based on the incoming +// requests' headers. See also +// https://github.com/grpc/grpc-go/blob/master/Documentation/encoding.md#using-a-codec. +// Will be supported throughout 1.x. +// +// # Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a +// later release. +func ForceServerCodec(codec encoding.Codec) ServerOption { + return newFuncServerOption(func(o *serverOptions) { + o.codec = codec + }) +} + +// RPCCompressor returns a ServerOption that sets a compressor for outbound +// messages. For backward compatibility, all outbound messages will be sent +// using this compressor, regardless of incoming message compression. By +// default, server messages will be sent using the same compressor with which +// request messages were sent. +// +// Deprecated: use encoding.RegisterCompressor instead. Will be supported +// throughout 1.x. +func RPCCompressor(cp Compressor) ServerOption { + return newFuncServerOption(func(o *serverOptions) { + o.cp = cp + }) +} + +// RPCDecompressor returns a ServerOption that sets a decompressor for inbound +// messages. It has higher priority than decompressors registered via +// encoding.RegisterCompressor. +// +// Deprecated: use encoding.RegisterCompressor instead. Will be supported +// throughout 1.x. +func RPCDecompressor(dc Decompressor) ServerOption { + return newFuncServerOption(func(o *serverOptions) { + o.dc = dc + }) +} + +// MaxMsgSize returns a ServerOption to set the max message size in bytes the server can receive. +// If this is not set, gRPC uses the default limit. +// +// Deprecated: use MaxRecvMsgSize instead. Will be supported throughout 1.x. +func MaxMsgSize(m int) ServerOption { + return MaxRecvMsgSize(m) +} + +// MaxRecvMsgSize returns a ServerOption to set the max message size in bytes the server can receive. +// If this is not set, gRPC uses the default 4MB. +func MaxRecvMsgSize(m int) ServerOption { + return newFuncServerOption(func(o *serverOptions) { + o.maxReceiveMessageSize = m + }) +} + +// MaxSendMsgSize returns a ServerOption to set the max message size in bytes the server can send. +// If this is not set, gRPC uses the default `math.MaxInt32`. +func MaxSendMsgSize(m int) ServerOption { + return newFuncServerOption(func(o *serverOptions) { + o.maxSendMessageSize = m + }) +} + +// MaxConcurrentStreams returns a ServerOption that will apply a limit on the number +// of concurrent streams to each ServerTransport. 
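+//
+// For example (sketch; listener setup and service registration elided):
+//
+//	srv := grpc.NewServer(
+//		grpc.MaxConcurrentStreams(256),
+//		grpc.MaxRecvMsgSize(8*1024*1024),
+//	)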
+func MaxConcurrentStreams(n uint32) ServerOption { + return newFuncServerOption(func(o *serverOptions) { + o.maxConcurrentStreams = n + }) +} + +// Creds returns a ServerOption that sets credentials for server connections. +func Creds(c credentials.TransportCredentials) ServerOption { + return newFuncServerOption(func(o *serverOptions) { + o.creds = c + }) +} + +// UnaryInterceptor returns a ServerOption that sets the UnaryServerInterceptor for the +// server. Only one unary interceptor can be installed. The construction of multiple +// interceptors (e.g., chaining) can be implemented at the caller. +func UnaryInterceptor(i UnaryServerInterceptor) ServerOption { + return newFuncServerOption(func(o *serverOptions) { + if o.unaryInt != nil { + panic("The unary server interceptor was already set and may not be reset.") + } + o.unaryInt = i + }) +} + +// ChainUnaryInterceptor returns a ServerOption that specifies the chained interceptor +// for unary RPCs. The first interceptor will be the outer most, +// while the last interceptor will be the inner most wrapper around the real call. +// All unary interceptors added by this method will be chained. +func ChainUnaryInterceptor(interceptors ...UnaryServerInterceptor) ServerOption { + return newFuncServerOption(func(o *serverOptions) { + o.chainUnaryInts = append(o.chainUnaryInts, interceptors...) + }) +} + +// StreamInterceptor returns a ServerOption that sets the StreamServerInterceptor for the +// server. Only one stream interceptor can be installed. +func StreamInterceptor(i StreamServerInterceptor) ServerOption { + return newFuncServerOption(func(o *serverOptions) { + if o.streamInt != nil { + panic("The stream server interceptor was already set and may not be reset.") + } + o.streamInt = i + }) +} + +// ChainStreamInterceptor returns a ServerOption that specifies the chained interceptor +// for streaming RPCs. The first interceptor will be the outer most, +// while the last interceptor will be the inner most wrapper around the real call. +// All stream interceptors added by this method will be chained. +func ChainStreamInterceptor(interceptors ...StreamServerInterceptor) ServerOption { + return newFuncServerOption(func(o *serverOptions) { + o.chainStreamInts = append(o.chainStreamInts, interceptors...) + }) +} + +// InTapHandle returns a ServerOption that sets the tap handle for all the server +// transport to be created. Only one can be installed. +// +// # Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a +// later release. +func InTapHandle(h tap.ServerInHandle) ServerOption { + return newFuncServerOption(func(o *serverOptions) { + if o.inTapHandle != nil { + panic("The tap handle was already set and may not be reset.") + } + o.inTapHandle = h + }) +} + +// StatsHandler returns a ServerOption that sets the stats handler for the server. +func StatsHandler(h stats.Handler) ServerOption { + return newFuncServerOption(func(o *serverOptions) { + if h == nil { + logger.Error("ignoring nil parameter in grpc.StatsHandler ServerOption") + // Do not allow a nil stats handler, which would otherwise cause + // panics. + return + } + o.statsHandlers = append(o.statsHandlers, h) + }) +} + +// binaryLogger returns a ServerOption that can set the binary logger for the +// server. 
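+//
+// Loggers installed this way are used in addition to any method logger
+// obtained from the global binarylog configuration; when both match a method,
+// both receive the log entries.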
+func binaryLogger(bl binarylog.Logger) ServerOption { + return newFuncServerOption(func(o *serverOptions) { + o.binaryLogger = bl + }) +} + +// UnknownServiceHandler returns a ServerOption that allows for adding a custom +// unknown service handler. The provided method is a bidi-streaming RPC service +// handler that will be invoked instead of returning the "unimplemented" gRPC +// error whenever a request is received for an unregistered service or method. +// The handling function and stream interceptor (if set) have full access to +// the ServerStream, including its Context. +func UnknownServiceHandler(streamHandler StreamHandler) ServerOption { + return newFuncServerOption(func(o *serverOptions) { + o.unknownStreamDesc = &StreamDesc{ + StreamName: "unknown_service_handler", + Handler: streamHandler, + // We need to assume that the users of the streamHandler will want to use both. + ClientStreams: true, + ServerStreams: true, + } + }) +} + +// ConnectionTimeout returns a ServerOption that sets the timeout for +// connection establishment (up to and including HTTP/2 handshaking) for all +// new connections. If this is not set, the default is 120 seconds. A zero or +// negative value will result in an immediate timeout. +// +// # Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a +// later release. +func ConnectionTimeout(d time.Duration) ServerOption { + return newFuncServerOption(func(o *serverOptions) { + o.connectionTimeout = d + }) +} + +// MaxHeaderListSize returns a ServerOption that sets the max (uncompressed) size +// of header list that the server is prepared to accept. +func MaxHeaderListSize(s uint32) ServerOption { + return newFuncServerOption(func(o *serverOptions) { + o.maxHeaderListSize = &s + }) +} + +// HeaderTableSize returns a ServerOption that sets the size of dynamic +// header table for stream. +// +// # Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a +// later release. +func HeaderTableSize(s uint32) ServerOption { + return newFuncServerOption(func(o *serverOptions) { + o.headerTableSize = &s + }) +} + +// NumStreamWorkers returns a ServerOption that sets the number of worker +// goroutines that should be used to process incoming streams. Setting this to +// zero (default) will disable workers and spawn a new goroutine for each +// stream. +// +// # Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a +// later release. +func NumStreamWorkers(numServerWorkers uint32) ServerOption { + // TODO: If/when this API gets stabilized (i.e. stream workers become the + // only way streams are processed), change the behavior of the zero value to + // a sane default. Preliminary experiments suggest that a value equal to the + // number of CPUs available is most performant; requires thorough testing. + return newFuncServerOption(func(o *serverOptions) { + o.numServerWorkers = numServerWorkers + }) +} + +// RecvBufferPool returns a ServerOption that configures the server +// to use the provided shared buffer pool for parsing incoming messages. Depending +// on the application's workload, this could result in reduced memory allocation. +// +// If you are unsure about how to implement a memory pool but want to utilize one, +// begin with grpc.NewSharedBufferPool. +// +// Note: The shared buffer pool feature will not be active if any of the following +// options are used: StatsHandler, EnableTracing, or binary logging. 
In such +// cases, the shared buffer pool will be ignored. +// +// # Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a +// later release. +func RecvBufferPool(bufferPool SharedBufferPool) ServerOption { + return newFuncServerOption(func(o *serverOptions) { + o.recvBufferPool = bufferPool + }) +} + +// serverWorkerResetThreshold defines how often the stack must be reset. Every +// N requests, by spawning a new goroutine in its place, a worker can reset its +// stack so that large stacks don't live in memory forever. 2^16 should allow +// each goroutine stack to live for at least a few seconds in a typical +// workload (assuming a QPS of a few thousand requests/sec). +const serverWorkerResetThreshold = 1 << 16 + +// serverWorkers blocks on a *transport.Stream channel forever and waits for +// data to be fed by serveStreams. This allows multiple requests to be +// processed by the same goroutine, removing the need for expensive stack +// re-allocations (see the runtime.morestack problem [1]). +// +// [1] https://github.com/golang/go/issues/18138 +func (s *Server) serverWorker() { + for completed := 0; completed < serverWorkerResetThreshold; completed++ { + data, ok := <-s.serverWorkerChannel + if !ok { + return + } + s.handleSingleStream(data) + } + go s.serverWorker() +} + +func (s *Server) handleSingleStream(data *serverWorkerData) { + defer data.wg.Done() + s.handleStream(data.st, data.stream, s.traceInfo(data.st, data.stream)) +} + +// initServerWorkers creates worker goroutines and a channel to process incoming +// connections to reduce the time spent overall on runtime.morestack. +func (s *Server) initServerWorkers() { + s.serverWorkerChannel = make(chan *serverWorkerData) + for i := uint32(0); i < s.opts.numServerWorkers; i++ { + go s.serverWorker() + } +} + +func (s *Server) stopServerWorkers() { + close(s.serverWorkerChannel) +} + +// NewServer creates a gRPC server which has no service registered and has not +// started to accept requests yet. +func NewServer(opt ...ServerOption) *Server { + opts := defaultServerOptions + for _, o := range globalServerOptions { + o.apply(&opts) + } + for _, o := range opt { + o.apply(&opts) + } + s := &Server{ + lis: make(map[net.Listener]bool), + opts: opts, + conns: make(map[string]map[transport.ServerTransport]bool), + services: make(map[string]*serviceInfo), + quit: grpcsync.NewEvent(), + done: grpcsync.NewEvent(), + czData: new(channelzData), + } + chainUnaryServerInterceptors(s) + chainStreamServerInterceptors(s) + s.cv = sync.NewCond(&s.mu) + if EnableTracing { + _, file, line, _ := runtime.Caller(1) + s.events = trace.NewEventLog("grpc.Server", fmt.Sprintf("%s:%d", file, line)) + } + + if s.opts.numServerWorkers > 0 { + s.initServerWorkers() + } + + s.channelzID = channelz.RegisterServer(&channelzServer{s}, "") + channelz.Info(logger, s.channelzID, "Server created") + return s +} + +// printf records an event in s's event log, unless s has been stopped. +// REQUIRES s.mu is held. +func (s *Server) printf(format string, a ...any) { + if s.events != nil { + s.events.Printf(format, a...) + } +} + +// errorf records an error in s's event log, unless s has been stopped. +// REQUIRES s.mu is held. +func (s *Server) errorf(format string, a ...any) { + if s.events != nil { + s.events.Errorf(format, a...) + } +} + +// ServiceRegistrar wraps a single method that supports service registration. 
It
+// enables users to pass concrete types other than grpc.Server to the service
+// registration methods exported by the IDL generated code.
+type ServiceRegistrar interface {
+	// RegisterService registers a service and its implementation to the
+	// concrete type implementing this interface. It may not be called
+	// once the server has started serving.
+	// desc describes the service and its methods and handlers. impl is the
+	// service implementation which is passed to the method handlers.
+	RegisterService(desc *ServiceDesc, impl any)
+}
+
+// RegisterService registers a service and its implementation to the gRPC
+// server. It is called from the IDL generated code. This must be called before
+// invoking Serve. If ss is non-nil (for legacy code), its type is checked to
+// ensure it implements sd.HandlerType.
+func (s *Server) RegisterService(sd *ServiceDesc, ss any) {
+	if ss != nil {
+		ht := reflect.TypeOf(sd.HandlerType).Elem()
+		st := reflect.TypeOf(ss)
+		if !st.Implements(ht) {
+			logger.Fatalf("grpc: Server.RegisterService found the handler of type %v that does not satisfy %v", st, ht)
+		}
+	}
+	s.register(sd, ss)
+}
+
+func (s *Server) register(sd *ServiceDesc, ss any) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	s.printf("RegisterService(%q)", sd.ServiceName)
+	if s.serve {
+		logger.Fatalf("grpc: Server.RegisterService after Server.Serve for %q", sd.ServiceName)
+	}
+	if _, ok := s.services[sd.ServiceName]; ok {
+		logger.Fatalf("grpc: Server.RegisterService found duplicate service registration for %q", sd.ServiceName)
+	}
+	info := &serviceInfo{
+		serviceImpl: ss,
+		methods:     make(map[string]*MethodDesc),
+		streams:     make(map[string]*StreamDesc),
+		mdata:       sd.Metadata,
+	}
+	for i := range sd.Methods {
+		d := &sd.Methods[i]
+		info.methods[d.MethodName] = d
+	}
+	for i := range sd.Streams {
+		d := &sd.Streams[i]
+		info.streams[d.StreamName] = d
+	}
+	s.services[sd.ServiceName] = info
+}
+
+// MethodInfo contains the information of an RPC including its method name and type.
+type MethodInfo struct {
+	// Name is the method name only, without the service name or package name.
+	Name string
+	// IsClientStream indicates whether the RPC is a client streaming RPC.
+	IsClientStream bool
+	// IsServerStream indicates whether the RPC is a server streaming RPC.
+	IsServerStream bool
+}
+
+// ServiceInfo contains unary RPC method info, streaming RPC method info and metadata for a service.
+type ServiceInfo struct {
+	Methods []MethodInfo
+	// Metadata is the metadata specified in ServiceDesc when registering service.
+	Metadata any
+}
+
+// GetServiceInfo returns a map from service names to ServiceInfo.
+// Service names include the package names, in the form of <package>.<service>.
+func (s *Server) GetServiceInfo() map[string]ServiceInfo {
+	ret := make(map[string]ServiceInfo)
+	for n, srv := range s.services {
+		methods := make([]MethodInfo, 0, len(srv.methods)+len(srv.streams))
+		for m := range srv.methods {
+			methods = append(methods, MethodInfo{
+				Name:           m,
+				IsClientStream: false,
+				IsServerStream: false,
+			})
+		}
+		for m, d := range srv.streams {
+			methods = append(methods, MethodInfo{
+				Name:           m,
+				IsClientStream: d.ClientStreams,
+				IsServerStream: d.ServerStreams,
+			})
+		}
+
+		ret[n] = ServiceInfo{
+			Methods:  methods,
+			Metadata: srv.mdata,
+		}
+	}
+	return ret
+}
+
+// ErrServerStopped indicates that the operation is now illegal because of
+// the server being stopped.
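+//
+// A typical check after Serve returns (srv and lis are placeholders):
+//
+//	if err := srv.Serve(lis); err != nil && !errors.Is(err, grpc.ErrServerStopped) {
+//		log.Fatalf("serve: %v", err)
+//	}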
+var ErrServerStopped = errors.New("grpc: the server has been stopped") + +type listenSocket struct { + net.Listener + channelzID *channelz.Identifier +} + +func (l *listenSocket) ChannelzMetric() *channelz.SocketInternalMetric { + return &channelz.SocketInternalMetric{ + SocketOptions: channelz.GetSocketOption(l.Listener), + LocalAddr: l.Listener.Addr(), + } +} + +func (l *listenSocket) Close() error { + err := l.Listener.Close() + channelz.RemoveEntry(l.channelzID) + channelz.Info(logger, l.channelzID, "ListenSocket deleted") + return err +} + +// Serve accepts incoming connections on the listener lis, creating a new +// ServerTransport and service goroutine for each. The service goroutines +// read gRPC requests and then call the registered handlers to reply to them. +// Serve returns when lis.Accept fails with fatal errors. lis will be closed when +// this method returns. +// Serve will return a non-nil error unless Stop or GracefulStop is called. +func (s *Server) Serve(lis net.Listener) error { + s.mu.Lock() + s.printf("serving") + s.serve = true + if s.lis == nil { + // Serve called after Stop or GracefulStop. + s.mu.Unlock() + lis.Close() + return ErrServerStopped + } + + s.serveWG.Add(1) + defer func() { + s.serveWG.Done() + if s.quit.HasFired() { + // Stop or GracefulStop called; block until done and return nil. + <-s.done.Done() + } + }() + + ls := &listenSocket{Listener: lis} + s.lis[ls] = true + + defer func() { + s.mu.Lock() + if s.lis != nil && s.lis[ls] { + ls.Close() + delete(s.lis, ls) + } + s.mu.Unlock() + }() + + var err error + ls.channelzID, err = channelz.RegisterListenSocket(ls, s.channelzID, lis.Addr().String()) + if err != nil { + s.mu.Unlock() + return err + } + s.mu.Unlock() + channelz.Info(logger, ls.channelzID, "ListenSocket created") + + var tempDelay time.Duration // how long to sleep on accept failure + for { + rawConn, err := lis.Accept() + if err != nil { + if ne, ok := err.(interface { + Temporary() bool + }); ok && ne.Temporary() { + if tempDelay == 0 { + tempDelay = 5 * time.Millisecond + } else { + tempDelay *= 2 + } + if max := 1 * time.Second; tempDelay > max { + tempDelay = max + } + s.mu.Lock() + s.printf("Accept error: %v; retrying in %v", err, tempDelay) + s.mu.Unlock() + timer := time.NewTimer(tempDelay) + select { + case <-timer.C: + case <-s.quit.Done(): + timer.Stop() + return nil + } + continue + } + s.mu.Lock() + s.printf("done serving; Accept = %v", err) + s.mu.Unlock() + + if s.quit.HasFired() { + return nil + } + return err + } + tempDelay = 0 + // Start a new goroutine to deal with rawConn so we don't stall this Accept + // loop goroutine. + // + // Make sure we account for the goroutine so GracefulStop doesn't nil out + // s.conns before this conn can be added. + s.serveWG.Add(1) + go func() { + s.handleRawConn(lis.Addr().String(), rawConn) + s.serveWG.Done() + }() + } +} + +// handleRawConn forks a goroutine to handle a just-accepted connection that +// has not had any I/O performed on it yet. 
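+// The handshake must complete within opts.connectionTimeout; the deadline is
+// cleared again once the HTTP/2 transport is established.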
+func (s *Server) handleRawConn(lisAddr string, rawConn net.Conn) { + if s.quit.HasFired() { + rawConn.Close() + return + } + rawConn.SetDeadline(time.Now().Add(s.opts.connectionTimeout)) + + // Finish handshaking (HTTP2) + st := s.newHTTP2Transport(rawConn) + rawConn.SetDeadline(time.Time{}) + if st == nil { + return + } + + if !s.addConn(lisAddr, st) { + return + } + go func() { + s.serveStreams(st) + s.removeConn(lisAddr, st) + }() +} + +func (s *Server) drainServerTransports(addr string) { + s.mu.Lock() + conns := s.conns[addr] + for st := range conns { + st.Drain("") + } + s.mu.Unlock() +} + +// newHTTP2Transport sets up a http/2 transport (using the +// gRPC http2 server transport in transport/http2_server.go). +func (s *Server) newHTTP2Transport(c net.Conn) transport.ServerTransport { + config := &transport.ServerConfig{ + MaxStreams: s.opts.maxConcurrentStreams, + ConnectionTimeout: s.opts.connectionTimeout, + Credentials: s.opts.creds, + InTapHandle: s.opts.inTapHandle, + StatsHandlers: s.opts.statsHandlers, + KeepaliveParams: s.opts.keepaliveParams, + KeepalivePolicy: s.opts.keepalivePolicy, + InitialWindowSize: s.opts.initialWindowSize, + InitialConnWindowSize: s.opts.initialConnWindowSize, + WriteBufferSize: s.opts.writeBufferSize, + ReadBufferSize: s.opts.readBufferSize, + SharedWriteBuffer: s.opts.sharedWriteBuffer, + ChannelzParentID: s.channelzID, + MaxHeaderListSize: s.opts.maxHeaderListSize, + HeaderTableSize: s.opts.headerTableSize, + } + st, err := transport.NewServerTransport(c, config) + if err != nil { + s.mu.Lock() + s.errorf("NewServerTransport(%q) failed: %v", c.RemoteAddr(), err) + s.mu.Unlock() + // ErrConnDispatched means that the connection was dispatched away from + // gRPC; those connections should be left open. + if err != credentials.ErrConnDispatched { + // Don't log on ErrConnDispatched and io.EOF to prevent log spam. + if err != io.EOF { + channelz.Info(logger, s.channelzID, "grpc: Server.Serve failed to create ServerTransport: ", err) + } + c.Close() + } + return nil + } + + return st +} + +func (s *Server) serveStreams(st transport.ServerTransport) { + defer st.Close(errors.New("finished serving streams for the server transport")) + var wg sync.WaitGroup + + st.HandleStreams(func(stream *transport.Stream) { + wg.Add(1) + if s.opts.numServerWorkers > 0 { + data := &serverWorkerData{st: st, wg: &wg, stream: stream} + select { + case s.serverWorkerChannel <- data: + return + default: + // If all stream workers are busy, fallback to the default code path. + } + } + go func() { + defer wg.Done() + s.handleStream(st, stream, s.traceInfo(st, stream)) + }() + }, func(ctx context.Context, method string) context.Context { + if !EnableTracing { + return ctx + } + tr := trace.New("grpc.Recv."+methodFamily(method), method) + return trace.NewContext(ctx, tr) + }) + wg.Wait() +} + +var _ http.Handler = (*Server)(nil) + +// ServeHTTP implements the Go standard library's http.Handler +// interface by responding to the gRPC request r, by looking up +// the requested gRPC method in the gRPC server s. +// +// The provided HTTP request must have arrived on an HTTP/2 +// connection. When using the Go standard library's server, +// practically this means that the Request must also have arrived +// over TLS. 
+// +// To share one port (such as 443 for https) between gRPC and an +// existing http.Handler, use a root http.Handler such as: +// +// if r.ProtoMajor == 2 && strings.HasPrefix( +// r.Header.Get("Content-Type"), "application/grpc") { +// grpcServer.ServeHTTP(w, r) +// } else { +// yourMux.ServeHTTP(w, r) +// } +// +// Note that ServeHTTP uses Go's HTTP/2 server implementation which is totally +// separate from grpc-go's HTTP/2 server. Performance and features may vary +// between the two paths. ServeHTTP does not support some gRPC features +// available through grpc-go's HTTP/2 server. +// +// # Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a +// later release. +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + st, err := transport.NewServerHandlerTransport(w, r, s.opts.statsHandlers) + if err != nil { + // Errors returned from transport.NewServerHandlerTransport have + // already been written to w. + return + } + if !s.addConn(listenerAddressForServeHTTP, st) { + return + } + defer s.removeConn(listenerAddressForServeHTTP, st) + s.serveStreams(st) +} + +// traceInfo returns a traceInfo and associates it with stream, if tracing is enabled. +// If tracing is not enabled, it returns nil. +func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Stream) (trInfo *traceInfo) { + if !EnableTracing { + return nil + } + tr, ok := trace.FromContext(stream.Context()) + if !ok { + return nil + } + + trInfo = &traceInfo{ + tr: tr, + firstLine: firstLine{ + client: false, + remoteAddr: st.RemoteAddr(), + }, + } + if dl, ok := stream.Context().Deadline(); ok { + trInfo.firstLine.deadline = time.Until(dl) + } + return trInfo +} + +func (s *Server) addConn(addr string, st transport.ServerTransport) bool { + s.mu.Lock() + defer s.mu.Unlock() + if s.conns == nil { + st.Close(errors.New("Server.addConn called when server has already been stopped")) + return false + } + if s.drain { + // Transport added after we drained our existing conns: drain it + // immediately. + st.Drain("") + } + + if s.conns[addr] == nil { + // Create a map entry if this is the first connection on this listener. + s.conns[addr] = make(map[transport.ServerTransport]bool) + } + s.conns[addr][st] = true + return true +} + +func (s *Server) removeConn(addr string, st transport.ServerTransport) { + s.mu.Lock() + defer s.mu.Unlock() + + conns := s.conns[addr] + if conns != nil { + delete(conns, st) + if len(conns) == 0 { + // If the last connection for this address is being removed, also + // remove the map entry corresponding to the address. This is used + // in GracefulStop() when waiting for all connections to be closed. 
+ delete(s.conns, addr) + } + s.cv.Broadcast() + } +} + +func (s *Server) channelzMetric() *channelz.ServerInternalMetric { + return &channelz.ServerInternalMetric{ + CallsStarted: atomic.LoadInt64(&s.czData.callsStarted), + CallsSucceeded: atomic.LoadInt64(&s.czData.callsSucceeded), + CallsFailed: atomic.LoadInt64(&s.czData.callsFailed), + LastCallStartedTimestamp: time.Unix(0, atomic.LoadInt64(&s.czData.lastCallStartedTime)), + } +} + +func (s *Server) incrCallsStarted() { + atomic.AddInt64(&s.czData.callsStarted, 1) + atomic.StoreInt64(&s.czData.lastCallStartedTime, time.Now().UnixNano()) +} + +func (s *Server) incrCallsSucceeded() { + atomic.AddInt64(&s.czData.callsSucceeded, 1) +} + +func (s *Server) incrCallsFailed() { + atomic.AddInt64(&s.czData.callsFailed, 1) +} + +func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg any, cp Compressor, opts *transport.Options, comp encoding.Compressor) error { + data, err := encode(s.getCodec(stream.ContentSubtype()), msg) + if err != nil { + channelz.Error(logger, s.channelzID, "grpc: server failed to encode response: ", err) + return err + } + compData, err := compress(data, cp, comp) + if err != nil { + channelz.Error(logger, s.channelzID, "grpc: server failed to compress response: ", err) + return err + } + hdr, payload := msgHeader(data, compData) + // TODO(dfawley): should we be checking len(data) instead? + if len(payload) > s.opts.maxSendMessageSize { + return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(payload), s.opts.maxSendMessageSize) + } + err = t.Write(stream, hdr, payload, opts) + if err == nil { + for _, sh := range s.opts.statsHandlers { + sh.HandleRPC(stream.Context(), outPayload(false, msg, data, payload, time.Now())) + } + } + return err +} + +// chainUnaryServerInterceptors chains all unary server interceptors into one. +func chainUnaryServerInterceptors(s *Server) { + // Prepend opts.unaryInt to the chaining interceptors if it exists, since unaryInt will + // be executed before any other chained interceptors. + interceptors := s.opts.chainUnaryInts + if s.opts.unaryInt != nil { + interceptors = append([]UnaryServerInterceptor{s.opts.unaryInt}, s.opts.chainUnaryInts...) 
+ } + + var chainedInt UnaryServerInterceptor + if len(interceptors) == 0 { + chainedInt = nil + } else if len(interceptors) == 1 { + chainedInt = interceptors[0] + } else { + chainedInt = chainUnaryInterceptors(interceptors) + } + + s.opts.unaryInt = chainedInt +} + +func chainUnaryInterceptors(interceptors []UnaryServerInterceptor) UnaryServerInterceptor { + return func(ctx context.Context, req any, info *UnaryServerInfo, handler UnaryHandler) (any, error) { + return interceptors[0](ctx, req, info, getChainUnaryHandler(interceptors, 0, info, handler)) + } +} + +func getChainUnaryHandler(interceptors []UnaryServerInterceptor, curr int, info *UnaryServerInfo, finalHandler UnaryHandler) UnaryHandler { + if curr == len(interceptors)-1 { + return finalHandler + } + return func(ctx context.Context, req any) (any, error) { + return interceptors[curr+1](ctx, req, info, getChainUnaryHandler(interceptors, curr+1, info, finalHandler)) + } +} + +func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.Stream, info *serviceInfo, md *MethodDesc, trInfo *traceInfo) (err error) { + shs := s.opts.statsHandlers + if len(shs) != 0 || trInfo != nil || channelz.IsOn() { + if channelz.IsOn() { + s.incrCallsStarted() + } + var statsBegin *stats.Begin + for _, sh := range shs { + beginTime := time.Now() + statsBegin = &stats.Begin{ + BeginTime: beginTime, + IsClientStream: false, + IsServerStream: false, + } + sh.HandleRPC(stream.Context(), statsBegin) + } + if trInfo != nil { + trInfo.tr.LazyLog(&trInfo.firstLine, false) + } + // The deferred error handling for tracing, stats handler and channelz are + // combined into one function to reduce stack usage -- a defer takes ~56-64 + // bytes on the stack, so overflowing the stack will require a stack + // re-allocation, which is expensive. + // + // To maintain behavior similar to separate deferred statements, statements + // should be executed in the reverse order. That is, tracing first, stats + // handler second, and channelz last. Note that panics *within* defers will + // lead to different behavior, but that's an acceptable compromise; that + // would be undefined behavior territory anyway. 
+ defer func() { + if trInfo != nil { + if err != nil && err != io.EOF { + trInfo.tr.LazyLog(&fmtStringer{"%v", []any{err}}, true) + trInfo.tr.SetError() + } + trInfo.tr.Finish() + } + + for _, sh := range shs { + end := &stats.End{ + BeginTime: statsBegin.BeginTime, + EndTime: time.Now(), + } + if err != nil && err != io.EOF { + end.Error = toRPCErr(err) + } + sh.HandleRPC(stream.Context(), end) + } + + if channelz.IsOn() { + if err != nil && err != io.EOF { + s.incrCallsFailed() + } else { + s.incrCallsSucceeded() + } + } + }() + } + var binlogs []binarylog.MethodLogger + if ml := binarylog.GetMethodLogger(stream.Method()); ml != nil { + binlogs = append(binlogs, ml) + } + if s.opts.binaryLogger != nil { + if ml := s.opts.binaryLogger.GetMethodLogger(stream.Method()); ml != nil { + binlogs = append(binlogs, ml) + } + } + if len(binlogs) != 0 { + ctx := stream.Context() + md, _ := metadata.FromIncomingContext(ctx) + logEntry := &binarylog.ClientHeader{ + Header: md, + MethodName: stream.Method(), + PeerAddr: nil, + } + if deadline, ok := ctx.Deadline(); ok { + logEntry.Timeout = time.Until(deadline) + if logEntry.Timeout < 0 { + logEntry.Timeout = 0 + } + } + if a := md[":authority"]; len(a) > 0 { + logEntry.Authority = a[0] + } + if peer, ok := peer.FromContext(ctx); ok { + logEntry.PeerAddr = peer.Addr + } + for _, binlog := range binlogs { + binlog.Log(ctx, logEntry) + } + } + + // comp and cp are used for compression. decomp and dc are used for + // decompression. If comp and decomp are both set, they are the same; + // however they are kept separate to ensure that at most one of the + // compressor/decompressor variable pairs are set for use later. + var comp, decomp encoding.Compressor + var cp Compressor + var dc Decompressor + var sendCompressorName string + + // If dc is set and matches the stream's compression, use it. Otherwise, try + // to find a matching registered compressor for decomp. + if rc := stream.RecvCompress(); s.opts.dc != nil && s.opts.dc.Type() == rc { + dc = s.opts.dc + } else if rc != "" && rc != encoding.Identity { + decomp = encoding.GetCompressor(rc) + if decomp == nil { + st := status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", rc) + t.WriteStatus(stream, st) + return st.Err() + } + } + + // If cp is set, use it. Otherwise, attempt to compress the response using + // the incoming message compression method. + // + // NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686. + if s.opts.cp != nil { + cp = s.opts.cp + sendCompressorName = cp.Type() + } else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity { + // Legacy compressor not specified; attempt to respond with same encoding. 
+ comp = encoding.GetCompressor(rc) + if comp != nil { + sendCompressorName = comp.Name() + } + } + + if sendCompressorName != "" { + if err := stream.SetSendCompress(sendCompressorName); err != nil { + return status.Errorf(codes.Internal, "grpc: failed to set send compressor: %v", err) + } + } + + var payInfo *payloadInfo + if len(shs) != 0 || len(binlogs) != 0 { + payInfo = &payloadInfo{} + } + d, err := recvAndDecompress(&parser{r: stream, recvBufferPool: s.opts.recvBufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp) + if err != nil { + if e := t.WriteStatus(stream, status.Convert(err)); e != nil { + channelz.Warningf(logger, s.channelzID, "grpc: Server.processUnaryRPC failed to write status: %v", e) + } + return err + } + if channelz.IsOn() { + t.IncrMsgRecv() + } + df := func(v any) error { + if err := s.getCodec(stream.ContentSubtype()).Unmarshal(d, v); err != nil { + return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err) + } + for _, sh := range shs { + sh.HandleRPC(stream.Context(), &stats.InPayload{ + RecvTime: time.Now(), + Payload: v, + Length: len(d), + WireLength: payInfo.compressedLength + headerLen, + CompressedLength: payInfo.compressedLength, + Data: d, + }) + } + if len(binlogs) != 0 { + cm := &binarylog.ClientMessage{ + Message: d, + } + for _, binlog := range binlogs { + binlog.Log(stream.Context(), cm) + } + } + if trInfo != nil { + trInfo.tr.LazyLog(&payload{sent: false, msg: v}, true) + } + return nil + } + ctx := NewContextWithServerTransportStream(stream.Context(), stream) + reply, appErr := md.Handler(info.serviceImpl, ctx, df, s.opts.unaryInt) + if appErr != nil { + appStatus, ok := status.FromError(appErr) + if !ok { + // Convert non-status application error to a status error with code + // Unknown, but handle context errors specifically. + appStatus = status.FromContextError(appErr) + appErr = appStatus.Err() + } + if trInfo != nil { + trInfo.tr.LazyLog(stringer(appStatus.Message()), true) + trInfo.tr.SetError() + } + if e := t.WriteStatus(stream, appStatus); e != nil { + channelz.Warningf(logger, s.channelzID, "grpc: Server.processUnaryRPC failed to write status: %v", e) + } + if len(binlogs) != 0 { + if h, _ := stream.Header(); h.Len() > 0 { + // Only log serverHeader if there was header. Otherwise it can + // be trailer only. + sh := &binarylog.ServerHeader{ + Header: h, + } + for _, binlog := range binlogs { + binlog.Log(stream.Context(), sh) + } + } + st := &binarylog.ServerTrailer{ + Trailer: stream.Trailer(), + Err: appErr, + } + for _, binlog := range binlogs { + binlog.Log(stream.Context(), st) + } + } + return appErr + } + if trInfo != nil { + trInfo.tr.LazyLog(stringer("OK"), false) + } + opts := &transport.Options{Last: true} + + // Server handler could have set new compressor by calling SetSendCompressor. + // In case it is set, we need to use it for compressing outbound message. + if stream.SendCompress() != sendCompressorName { + comp = encoding.GetCompressor(stream.SendCompress()) + } + if err := s.sendResponse(t, stream, reply, cp, opts, comp); err != nil { + if err == io.EOF { + // The entire stream is done (for unary RPC only). + return err + } + if sts, ok := status.FromError(err); ok { + if e := t.WriteStatus(stream, sts); e != nil { + channelz.Warningf(logger, s.channelzID, "grpc: Server.processUnaryRPC failed to write status: %v", e) + } + } else { + switch st := err.(type) { + case transport.ConnectionError: + // Nothing to do here. 
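+				// (The connection is already broken, so attempting to write
+				// a status back to the client would be pointless.)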
+ default: + panic(fmt.Sprintf("grpc: Unexpected error (%T) from sendResponse: %v", st, st)) + } + } + if len(binlogs) != 0 { + h, _ := stream.Header() + sh := &binarylog.ServerHeader{ + Header: h, + } + st := &binarylog.ServerTrailer{ + Trailer: stream.Trailer(), + Err: appErr, + } + for _, binlog := range binlogs { + binlog.Log(stream.Context(), sh) + binlog.Log(stream.Context(), st) + } + } + return err + } + if len(binlogs) != 0 { + h, _ := stream.Header() + sh := &binarylog.ServerHeader{ + Header: h, + } + sm := &binarylog.ServerMessage{ + Message: reply, + } + for _, binlog := range binlogs { + binlog.Log(stream.Context(), sh) + binlog.Log(stream.Context(), sm) + } + } + if channelz.IsOn() { + t.IncrMsgSent() + } + if trInfo != nil { + trInfo.tr.LazyLog(&payload{sent: true, msg: reply}, true) + } + // TODO: Should we be logging if writing status failed here, like above? + // Should the logging be in WriteStatus? Should we ignore the WriteStatus + // error or allow the stats handler to see it? + if len(binlogs) != 0 { + st := &binarylog.ServerTrailer{ + Trailer: stream.Trailer(), + Err: appErr, + } + for _, binlog := range binlogs { + binlog.Log(stream.Context(), st) + } + } + return t.WriteStatus(stream, statusOK) +} + +// chainStreamServerInterceptors chains all stream server interceptors into one. +func chainStreamServerInterceptors(s *Server) { + // Prepend opts.streamInt to the chaining interceptors if it exists, since streamInt will + // be executed before any other chained interceptors. + interceptors := s.opts.chainStreamInts + if s.opts.streamInt != nil { + interceptors = append([]StreamServerInterceptor{s.opts.streamInt}, s.opts.chainStreamInts...) + } + + var chainedInt StreamServerInterceptor + if len(interceptors) == 0 { + chainedInt = nil + } else if len(interceptors) == 1 { + chainedInt = interceptors[0] + } else { + chainedInt = chainStreamInterceptors(interceptors) + } + + s.opts.streamInt = chainedInt +} + +func chainStreamInterceptors(interceptors []StreamServerInterceptor) StreamServerInterceptor { + return func(srv any, ss ServerStream, info *StreamServerInfo, handler StreamHandler) error { + return interceptors[0](srv, ss, info, getChainStreamHandler(interceptors, 0, info, handler)) + } +} + +func getChainStreamHandler(interceptors []StreamServerInterceptor, curr int, info *StreamServerInfo, finalHandler StreamHandler) StreamHandler { + if curr == len(interceptors)-1 { + return finalHandler + } + return func(srv any, stream ServerStream) error { + return interceptors[curr+1](srv, stream, info, getChainStreamHandler(interceptors, curr+1, info, finalHandler)) + } +} + +func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transport.Stream, info *serviceInfo, sd *StreamDesc, trInfo *traceInfo) (err error) { + if channelz.IsOn() { + s.incrCallsStarted() + } + shs := s.opts.statsHandlers + var statsBegin *stats.Begin + if len(shs) != 0 { + beginTime := time.Now() + statsBegin = &stats.Begin{ + BeginTime: beginTime, + IsClientStream: sd.ClientStreams, + IsServerStream: sd.ServerStreams, + } + for _, sh := range shs { + sh.HandleRPC(stream.Context(), statsBegin) + } + } + ctx := NewContextWithServerTransportStream(stream.Context(), stream) + ss := &serverStream{ + ctx: ctx, + t: t, + s: stream, + p: &parser{r: stream, recvBufferPool: s.opts.recvBufferPool}, + codec: s.getCodec(stream.ContentSubtype()), + maxReceiveMessageSize: s.opts.maxReceiveMessageSize, + maxSendMessageSize: s.opts.maxSendMessageSize, + trInfo: trInfo, + statsHandler: shs, + } + + 
if len(shs) != 0 || trInfo != nil || channelz.IsOn() { + // See comment in processUnaryRPC on defers. + defer func() { + if trInfo != nil { + ss.mu.Lock() + if err != nil && err != io.EOF { + ss.trInfo.tr.LazyLog(&fmtStringer{"%v", []any{err}}, true) + ss.trInfo.tr.SetError() + } + ss.trInfo.tr.Finish() + ss.trInfo.tr = nil + ss.mu.Unlock() + } + + if len(shs) != 0 { + end := &stats.End{ + BeginTime: statsBegin.BeginTime, + EndTime: time.Now(), + } + if err != nil && err != io.EOF { + end.Error = toRPCErr(err) + } + for _, sh := range shs { + sh.HandleRPC(stream.Context(), end) + } + } + + if channelz.IsOn() { + if err != nil && err != io.EOF { + s.incrCallsFailed() + } else { + s.incrCallsSucceeded() + } + } + }() + } + + if ml := binarylog.GetMethodLogger(stream.Method()); ml != nil { + ss.binlogs = append(ss.binlogs, ml) + } + if s.opts.binaryLogger != nil { + if ml := s.opts.binaryLogger.GetMethodLogger(stream.Method()); ml != nil { + ss.binlogs = append(ss.binlogs, ml) + } + } + if len(ss.binlogs) != 0 { + md, _ := metadata.FromIncomingContext(ctx) + logEntry := &binarylog.ClientHeader{ + Header: md, + MethodName: stream.Method(), + PeerAddr: nil, + } + if deadline, ok := ctx.Deadline(); ok { + logEntry.Timeout = time.Until(deadline) + if logEntry.Timeout < 0 { + logEntry.Timeout = 0 + } + } + if a := md[":authority"]; len(a) > 0 { + logEntry.Authority = a[0] + } + if peer, ok := peer.FromContext(ss.Context()); ok { + logEntry.PeerAddr = peer.Addr + } + for _, binlog := range ss.binlogs { + binlog.Log(stream.Context(), logEntry) + } + } + + // If dc is set and matches the stream's compression, use it. Otherwise, try + // to find a matching registered compressor for decomp. + if rc := stream.RecvCompress(); s.opts.dc != nil && s.opts.dc.Type() == rc { + ss.dc = s.opts.dc + } else if rc != "" && rc != encoding.Identity { + ss.decomp = encoding.GetCompressor(rc) + if ss.decomp == nil { + st := status.Newf(codes.Unimplemented, "grpc: Decompressor is not installed for grpc-encoding %q", rc) + t.WriteStatus(ss.s, st) + return st.Err() + } + } + + // If cp is set, use it. Otherwise, attempt to compress the response using + // the incoming message compression method. + // + // NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686. + if s.opts.cp != nil { + ss.cp = s.opts.cp + ss.sendCompressorName = s.opts.cp.Type() + } else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity { + // Legacy compressor not specified; attempt to respond with same encoding. 
+ ss.comp = encoding.GetCompressor(rc) + if ss.comp != nil { + ss.sendCompressorName = rc + } + } + + if ss.sendCompressorName != "" { + if err := stream.SetSendCompress(ss.sendCompressorName); err != nil { + return status.Errorf(codes.Internal, "grpc: failed to set send compressor: %v", err) + } + } + + ss.ctx = newContextWithRPCInfo(ss.ctx, false, ss.codec, ss.cp, ss.comp) + + if trInfo != nil { + trInfo.tr.LazyLog(&trInfo.firstLine, false) + } + var appErr error + var server any + if info != nil { + server = info.serviceImpl + } + if s.opts.streamInt == nil { + appErr = sd.Handler(server, ss) + } else { + info := &StreamServerInfo{ + FullMethod: stream.Method(), + IsClientStream: sd.ClientStreams, + IsServerStream: sd.ServerStreams, + } + appErr = s.opts.streamInt(server, ss, info, sd.Handler) + } + if appErr != nil { + appStatus, ok := status.FromError(appErr) + if !ok { + // Convert non-status application error to a status error with code + // Unknown, but handle context errors specifically. + appStatus = status.FromContextError(appErr) + appErr = appStatus.Err() + } + if trInfo != nil { + ss.mu.Lock() + ss.trInfo.tr.LazyLog(stringer(appStatus.Message()), true) + ss.trInfo.tr.SetError() + ss.mu.Unlock() + } + if len(ss.binlogs) != 0 { + st := &binarylog.ServerTrailer{ + Trailer: ss.s.Trailer(), + Err: appErr, + } + for _, binlog := range ss.binlogs { + binlog.Log(stream.Context(), st) + } + } + t.WriteStatus(ss.s, appStatus) + // TODO: Should we log an error from WriteStatus here and below? + return appErr + } + if trInfo != nil { + ss.mu.Lock() + ss.trInfo.tr.LazyLog(stringer("OK"), false) + ss.mu.Unlock() + } + if len(ss.binlogs) != 0 { + st := &binarylog.ServerTrailer{ + Trailer: ss.s.Trailer(), + Err: appErr, + } + for _, binlog := range ss.binlogs { + binlog.Log(stream.Context(), st) + } + } + return t.WriteStatus(ss.s, statusOK) +} + +func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream, trInfo *traceInfo) { + sm := stream.Method() + if sm != "" && sm[0] == '/' { + sm = sm[1:] + } + pos := strings.LastIndex(sm, "/") + if pos == -1 { + if trInfo != nil { + trInfo.tr.LazyLog(&fmtStringer{"Malformed method name %q", []any{sm}}, true) + trInfo.tr.SetError() + } + errDesc := fmt.Sprintf("malformed method name: %q", stream.Method()) + if err := t.WriteStatus(stream, status.New(codes.Unimplemented, errDesc)); err != nil { + if trInfo != nil { + trInfo.tr.LazyLog(&fmtStringer{"%v", []any{err}}, true) + trInfo.tr.SetError() + } + channelz.Warningf(logger, s.channelzID, "grpc: Server.handleStream failed to write status: %v", err) + } + if trInfo != nil { + trInfo.tr.Finish() + } + return + } + service := sm[:pos] + method := sm[pos+1:] + + srv, knownService := s.services[service] + if knownService { + if md, ok := srv.methods[method]; ok { + s.processUnaryRPC(t, stream, srv, md, trInfo) + return + } + if sd, ok := srv.streams[method]; ok { + s.processStreamingRPC(t, stream, srv, sd, trInfo) + return + } + } + // Unknown service, or known server unknown method. 
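+	// Fall back to the handler installed via UnknownServiceHandler, if any;
+	// otherwise reply with codes.Unimplemented below.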
+ if unknownDesc := s.opts.unknownStreamDesc; unknownDesc != nil { + s.processStreamingRPC(t, stream, nil, unknownDesc, trInfo) + return + } + var errDesc string + if !knownService { + errDesc = fmt.Sprintf("unknown service %v", service) + } else { + errDesc = fmt.Sprintf("unknown method %v for service %v", method, service) + } + if trInfo != nil { + trInfo.tr.LazyPrintf("%s", errDesc) + trInfo.tr.SetError() + } + if err := t.WriteStatus(stream, status.New(codes.Unimplemented, errDesc)); err != nil { + if trInfo != nil { + trInfo.tr.LazyLog(&fmtStringer{"%v", []any{err}}, true) + trInfo.tr.SetError() + } + channelz.Warningf(logger, s.channelzID, "grpc: Server.handleStream failed to write status: %v", err) + } + if trInfo != nil { + trInfo.tr.Finish() + } +} + +// The key to save ServerTransportStream in the context. +type streamKey struct{} + +// NewContextWithServerTransportStream creates a new context from ctx and +// attaches stream to it. +// +// # Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a +// later release. +func NewContextWithServerTransportStream(ctx context.Context, stream ServerTransportStream) context.Context { + return context.WithValue(ctx, streamKey{}, stream) +} + +// ServerTransportStream is a minimal interface that a transport stream must +// implement. This can be used to mock an actual transport stream for tests of +// handler code that use, for example, grpc.SetHeader (which requires some +// stream to be in context). +// +// See also NewContextWithServerTransportStream. +// +// # Experimental +// +// Notice: This type is EXPERIMENTAL and may be changed or removed in a +// later release. +type ServerTransportStream interface { + Method() string + SetHeader(md metadata.MD) error + SendHeader(md metadata.MD) error + SetTrailer(md metadata.MD) error +} + +// ServerTransportStreamFromContext returns the ServerTransportStream saved in +// ctx. Returns nil if the given context has no stream associated with it +// (which implies it is not an RPC invocation context). +// +// # Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a +// later release. +func ServerTransportStreamFromContext(ctx context.Context) ServerTransportStream { + s, _ := ctx.Value(streamKey{}).(ServerTransportStream) + return s +} + +// Stop stops the gRPC server. It immediately closes all open +// connections and listeners. +// It cancels all active RPCs on the server side and the corresponding +// pending RPCs on the client side will get notified by connection +// errors. +func (s *Server) Stop() { + s.quit.Fire() + + defer func() { + s.serveWG.Wait() + s.done.Fire() + }() + + s.channelzRemoveOnce.Do(func() { channelz.RemoveEntry(s.channelzID) }) + + s.mu.Lock() + listeners := s.lis + s.lis = nil + conns := s.conns + s.conns = nil + // interrupt GracefulStop if Stop and GracefulStop are called concurrently. + s.cv.Broadcast() + s.mu.Unlock() + + for lis := range listeners { + lis.Close() + } + for _, cs := range conns { + for st := range cs { + st.Close(errors.New("Server.Stop called")) + } + } + if s.opts.numServerWorkers > 0 { + s.stopServerWorkers() + } + + s.mu.Lock() + if s.events != nil { + s.events.Finish() + s.events = nil + } + s.mu.Unlock() +} + +// GracefulStop stops the gRPC server gracefully. It stops the server from +// accepting new connections and RPCs and blocks until all the pending RPCs are +// finished. 
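+//
+// A common shutdown sketch (srv, lis, and the signal wiring are illustrative):
+//
+//	ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM)
+//	defer stop()
+//	go func() {
+//		<-ctx.Done()
+//		srv.GracefulStop()
+//	}()
+//	_ = srv.Serve(lis)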
+func (s *Server) GracefulStop() { + s.quit.Fire() + defer s.done.Fire() + + s.channelzRemoveOnce.Do(func() { channelz.RemoveEntry(s.channelzID) }) + s.mu.Lock() + if s.conns == nil { + s.mu.Unlock() + return + } + + for lis := range s.lis { + lis.Close() + } + s.lis = nil + if !s.drain { + for _, conns := range s.conns { + for st := range conns { + st.Drain("graceful_stop") + } + } + s.drain = true + } + + // Wait for serving threads to be ready to exit. Only then can we be sure no + // new conns will be created. + s.mu.Unlock() + s.serveWG.Wait() + s.mu.Lock() + + for len(s.conns) != 0 { + s.cv.Wait() + } + s.conns = nil + if s.events != nil { + s.events.Finish() + s.events = nil + } + s.mu.Unlock() +} + +// contentSubtype must be lowercase +// cannot return nil +func (s *Server) getCodec(contentSubtype string) baseCodec { + if s.opts.codec != nil { + return s.opts.codec + } + if contentSubtype == "" { + return encoding.GetCodec(proto.Name) + } + codec := encoding.GetCodec(contentSubtype) + if codec == nil { + return encoding.GetCodec(proto.Name) + } + return codec +} + +// SetHeader sets the header metadata to be sent from the server to the client. +// The context provided must be the context passed to the server's handler. +// +// Streaming RPCs should prefer the SetHeader method of the ServerStream. +// +// When called multiple times, all the provided metadata will be merged. All +// the metadata will be sent out when one of the following happens: +// +// - grpc.SendHeader is called, or for streaming handlers, stream.SendHeader. +// - The first response message is sent. For unary handlers, this occurs when +// the handler returns; for streaming handlers, this can happen when stream's +// SendMsg method is called. +// - An RPC status is sent out (error or success). This occurs when the handler +// returns. +// +// SetHeader will fail if called after any of the events above. +// +// The error returned is compatible with the status package. However, the +// status code will often not match the RPC status as seen by the client +// application, and therefore, should not be relied upon for this purpose. +func SetHeader(ctx context.Context, md metadata.MD) error { + if md.Len() == 0 { + return nil + } + stream := ServerTransportStreamFromContext(ctx) + if stream == nil { + return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) + } + return stream.SetHeader(md) +} + +// SendHeader sends header metadata. It may be called at most once, and may not +// be called after any event that causes headers to be sent (see SetHeader for +// a complete list). The provided md and headers set by SetHeader() will be +// sent. +// +// The error returned is compatible with the status package. However, the +// status code will often not match the RPC status as seen by the client +// application, and therefore, should not be relied upon for this purpose. +func SendHeader(ctx context.Context, md metadata.MD) error { + stream := ServerTransportStreamFromContext(ctx) + if stream == nil { + return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) + } + if err := stream.SendHeader(md); err != nil { + return toRPCErr(err) + } + return nil +} + +// SetSendCompressor sets a compressor for outbound messages from the server. +// It must not be called after any event that causes headers to be sent +// (see ServerStream.SetHeader for the complete list). 
Provided compressor is +// used when below conditions are met: +// +// - compressor is registered via encoding.RegisterCompressor +// - compressor name must exist in the client advertised compressor names +// sent in grpc-accept-encoding header. Use ClientSupportedCompressors to +// get client supported compressor names. +// +// The context provided must be the context passed to the server's handler. +// It must be noted that compressor name encoding.Identity disables the +// outbound compression. +// By default, server messages will be sent using the same compressor with +// which request messages were sent. +// +// It is not safe to call SetSendCompressor concurrently with SendHeader and +// SendMsg. +// +// # Experimental +// +// Notice: This function is EXPERIMENTAL and may be changed or removed in a +// later release. +func SetSendCompressor(ctx context.Context, name string) error { + stream, ok := ServerTransportStreamFromContext(ctx).(*transport.Stream) + if !ok || stream == nil { + return fmt.Errorf("failed to fetch the stream from the given context") + } + + if err := validateSendCompressor(name, stream.ClientAdvertisedCompressors()); err != nil { + return fmt.Errorf("unable to set send compressor: %w", err) + } + + return stream.SetSendCompress(name) +} + +// ClientSupportedCompressors returns compressor names advertised by the client +// via grpc-accept-encoding header. +// +// The context provided must be the context passed to the server's handler. +// +// # Experimental +// +// Notice: This function is EXPERIMENTAL and may be changed or removed in a +// later release. +func ClientSupportedCompressors(ctx context.Context) ([]string, error) { + stream, ok := ServerTransportStreamFromContext(ctx).(*transport.Stream) + if !ok || stream == nil { + return nil, fmt.Errorf("failed to fetch the stream from the given context %v", ctx) + } + + return strings.Split(stream.ClientAdvertisedCompressors(), ","), nil +} + +// SetTrailer sets the trailer metadata that will be sent when an RPC returns. +// When called more than once, all the provided metadata will be merged. +// +// The error returned is compatible with the status package. However, the +// status code will often not match the RPC status as seen by the client +// application, and therefore, should not be relied upon for this purpose. +func SetTrailer(ctx context.Context, md metadata.MD) error { + if md.Len() == 0 { + return nil + } + stream := ServerTransportStreamFromContext(ctx) + if stream == nil { + return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) + } + return stream.SetTrailer(md) +} + +// Method returns the method string for the server context. The returned +// string is in the format of "/service/method". +func Method(ctx context.Context) (string, bool) { + s := ServerTransportStreamFromContext(ctx) + if s == nil { + return "", false + } + return s.Method(), true +} + +type channelzServer struct { + s *Server +} + +func (c *channelzServer) ChannelzMetric() *channelz.ServerInternalMetric { + return c.s.channelzMetric() +} + +// validateSendCompressor returns an error when given compressor name cannot be +// handled by the server or the client based on the advertised compressors. 
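+// clientCompressors is the comma-separated grpc-accept-encoding value
+// advertised by the client, e.g. "gzip,identity".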
+func validateSendCompressor(name, clientCompressors string) error { + if name == encoding.Identity { + return nil + } + + if !grpcutil.IsCompressorNameRegistered(name) { + return fmt.Errorf("compressor not registered %q", name) + } + + for _, c := range strings.Split(clientCompressors, ",") { + if c == name { + return nil // found match + } + } + return fmt.Errorf("client does not support compressor %q", name) +} diff --git a/vendor/google.golang.org/grpc/version.go b/vendor/google.golang.org/grpc/version.go new file mode 100644 index 0000000000..d3f5bcbfce --- /dev/null +++ b/vendor/google.golang.org/grpc/version.go @@ -0,0 +1,22 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package grpc + +// Version is the current grpc version. +const Version = "1.58.2" diff --git a/vendor/modules.txt b/vendor/modules.txt new file mode 100644 index 0000000000..484309c5cb --- /dev/null +++ b/vendor/modules.txt @@ -0,0 +1,1281 @@ +# github.com/0x6flab/namegenerator v1.1.0 +## explicit; go 1.19 +github.com/0x6flab/namegenerator +# github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 +## explicit; go 1.16 +github.com/Azure/go-ansiterm +github.com/Azure/go-ansiterm/winterm +# github.com/BurntSushi/toml v1.3.2 +## explicit; go 1.16 +github.com/BurntSushi/toml +github.com/BurntSushi/toml/internal +# github.com/CloudyKit/fastprinter v0.0.0-20200109182630-33d98a066a53 +## explicit; go 1.13 +github.com/CloudyKit/fastprinter +# github.com/CloudyKit/jet/v6 v6.2.0 +## explicit; go 1.12 +github.com/CloudyKit/jet/v6 +# github.com/Joker/jade v1.1.3 +## explicit; go 1.14 +github.com/Joker/jade +# github.com/Microsoft/go-winio v0.6.1 +## explicit; go 1.17 +github.com/Microsoft/go-winio +github.com/Microsoft/go-winio/internal/fs +github.com/Microsoft/go-winio/internal/socket +github.com/Microsoft/go-winio/internal/stringbuffer +github.com/Microsoft/go-winio/pkg/guid +# github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 +## explicit +github.com/Nvveen/Gotty +# github.com/Shopify/goreferrer v0.0.0-20220729165902-8cddb4f5de06 +## explicit; go 1.17 +github.com/Shopify/goreferrer +# github.com/andybalholm/brotli v1.0.5 +## explicit; go 1.12 +github.com/andybalholm/brotli +# github.com/apapsch/go-jsonmerge/v2 v2.0.0 +## explicit; go 1.12 +github.com/apapsch/go-jsonmerge/v2 +# github.com/authzed/authzed-go v0.10.0 +## explicit; go 1.18 +github.com/authzed/authzed-go/proto/authzed/api/v1 +github.com/authzed/authzed-go/v1 +# github.com/authzed/grpcutil v0.0.0-20230908193239-4286bb1d6403 +## explicit; go 1.20 +github.com/authzed/grpcutil +# github.com/aymerick/douceur v0.2.0 +## explicit +github.com/aymerick/douceur/css +github.com/aymerick/douceur/parser +# github.com/beorn7/perks v1.0.1 +## explicit; go 1.11 +github.com/beorn7/perks/quantile +# github.com/bytedance/sonic v1.10.2 +## explicit; go 1.16 +github.com/bytedance/sonic +github.com/bytedance/sonic/ast +github.com/bytedance/sonic/decoder +github.com/bytedance/sonic/encoder 
+github.com/bytedance/sonic/internal/abi +github.com/bytedance/sonic/internal/caching +github.com/bytedance/sonic/internal/cpu +github.com/bytedance/sonic/internal/decoder +github.com/bytedance/sonic/internal/encoder +github.com/bytedance/sonic/internal/jit +github.com/bytedance/sonic/internal/native +github.com/bytedance/sonic/internal/native/avx +github.com/bytedance/sonic/internal/native/avx2 +github.com/bytedance/sonic/internal/native/sse +github.com/bytedance/sonic/internal/native/types +github.com/bytedance/sonic/internal/resolver +github.com/bytedance/sonic/internal/rt +github.com/bytedance/sonic/loader +github.com/bytedance/sonic/option +github.com/bytedance/sonic/unquote +github.com/bytedance/sonic/utf8 +# github.com/caarlos0/env/v7 v7.1.0 +## explicit; go 1.17 +github.com/caarlos0/env/v7 +# github.com/cenkalti/backoff/v3 v3.2.2 +## explicit; go 1.12 +github.com/cenkalti/backoff/v3 +# github.com/cenkalti/backoff/v4 v4.2.1 +## explicit; go 1.18 +github.com/cenkalti/backoff/v4 +# github.com/certifi/gocertifi v0.0.0-20210507211836-431795d63e8d +## explicit; go 1.12 +github.com/certifi/gocertifi +# github.com/cespare/xxhash/v2 v2.2.0 +## explicit; go 1.11 +github.com/cespare/xxhash/v2 +# github.com/chenzhuoyu/base64x v0.0.0-20230717121745-296ad89f973d +## explicit; go 1.16 +github.com/chenzhuoyu/base64x +# github.com/chenzhuoyu/iasm v0.9.0 +## explicit; go 1.16 +github.com/chenzhuoyu/iasm/expr +github.com/chenzhuoyu/iasm/x86_64 +# github.com/containerd/continuity v0.4.2 +## explicit; go 1.19 +github.com/containerd/continuity/pathdriver +# github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc +## explicit +github.com/davecgh/go-spew/spew +# github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 +## explicit; go 1.17 +github.com/decred/dcrd/dcrec/secp256k1/v4 +# github.com/deepmap/oapi-codegen v1.15.0 +## explicit; go 1.20 +github.com/deepmap/oapi-codegen/pkg/runtime +github.com/deepmap/oapi-codegen/pkg/types +# github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f +## explicit +github.com/dgryski/go-rendezvous +# github.com/docker/cli v24.0.6+incompatible +## explicit +github.com/docker/cli/cli/compose/interpolation +github.com/docker/cli/cli/compose/loader +github.com/docker/cli/cli/compose/schema +github.com/docker/cli/cli/compose/template +github.com/docker/cli/cli/compose/types +github.com/docker/cli/opts +# github.com/docker/docker v24.0.6+incompatible +## explicit +github.com/docker/docker/api/types/blkiodev +github.com/docker/docker/api/types/container +github.com/docker/docker/api/types/filters +github.com/docker/docker/api/types/mount +github.com/docker/docker/api/types/network +github.com/docker/docker/api/types/strslice +github.com/docker/docker/api/types/swarm +github.com/docker/docker/api/types/swarm/runtime +github.com/docker/docker/api/types/versions +github.com/docker/docker/pkg/namesgenerator +# github.com/docker/go-connections v0.4.0 +## explicit +github.com/docker/go-connections/nat +# github.com/docker/go-units v0.5.0 +## explicit +github.com/docker/go-units +# github.com/dsnet/golib/memfile v1.0.0 +## explicit; go 1.12 +github.com/dsnet/golib/memfile +# github.com/eclipse/paho.mqtt.golang v1.4.3 +## explicit; go 1.18 +github.com/eclipse/paho.mqtt.golang +github.com/eclipse/paho.mqtt.golang/packets +# github.com/envoyproxy/protoc-gen-validate v1.0.2 +## explicit; go 1.19 +github.com/envoyproxy/protoc-gen-validate/validate +# github.com/fatih/color v1.15.0 +## explicit; go 1.17 +github.com/fatih/color +# github.com/fatih/structs v1.1.0 +## explicit 
+github.com/fatih/structs +# github.com/felixge/httpsnoop v1.0.3 +## explicit; go 1.13 +github.com/felixge/httpsnoop +# github.com/fiorix/go-smpp v0.0.0-20210403173735-2894b96e70ba +## explicit; go 1.16 +github.com/fiorix/go-smpp/smpp +github.com/fiorix/go-smpp/smpp/encoding +github.com/fiorix/go-smpp/smpp/pdu +github.com/fiorix/go-smpp/smpp/pdu/pdufield +github.com/fiorix/go-smpp/smpp/pdu/pdutext +github.com/fiorix/go-smpp/smpp/pdu/pdutlv +# github.com/flosch/pongo2/v4 v4.0.2 +## explicit; go 1.14 +github.com/flosch/pongo2/v4 +# github.com/fsnotify/fsnotify v1.6.0 +## explicit; go 1.16 +github.com/fsnotify/fsnotify +# github.com/fxamacker/cbor/v2 v2.5.0 +## explicit; go 1.12 +github.com/fxamacker/cbor/v2 +# github.com/gabriel-vasile/mimetype v1.4.3 +## explicit; go 1.20 +github.com/gabriel-vasile/mimetype +github.com/gabriel-vasile/mimetype/internal/charset +github.com/gabriel-vasile/mimetype/internal/json +github.com/gabriel-vasile/mimetype/internal/magic +# github.com/gin-contrib/sse v0.1.0 +## explicit; go 1.12 +github.com/gin-contrib/sse +# github.com/gin-gonic/gin v1.9.1 +## explicit; go 1.20 +github.com/gin-gonic/gin +github.com/gin-gonic/gin/binding +github.com/gin-gonic/gin/internal/bytesconv +github.com/gin-gonic/gin/internal/json +github.com/gin-gonic/gin/render +# github.com/go-chi/chi/v5 v5.0.10 +## explicit; go 1.14 +github.com/go-chi/chi/v5 +# github.com/go-gorp/gorp/v3 v3.1.0 +## explicit; go 1.18 +github.com/go-gorp/gorp/v3 +# github.com/go-jose/go-jose/v3 v3.0.0 +## explicit; go 1.12 +github.com/go-jose/go-jose/v3 +github.com/go-jose/go-jose/v3/cipher +github.com/go-jose/go-jose/v3/json +github.com/go-jose/go-jose/v3/jwt +# github.com/go-kit/kit v0.13.0 +## explicit; go 1.17 +github.com/go-kit/kit/endpoint +github.com/go-kit/kit/metrics +github.com/go-kit/kit/metrics/internal/lv +github.com/go-kit/kit/metrics/prometheus +github.com/go-kit/kit/transport +github.com/go-kit/kit/transport/grpc +github.com/go-kit/kit/transport/http +# github.com/go-kit/log v0.2.1 +## explicit; go 1.17 +github.com/go-kit/log +# github.com/go-logfmt/logfmt v0.6.0 +## explicit; go 1.17 +github.com/go-logfmt/logfmt +# github.com/go-logr/logr v1.2.4 +## explicit; go 1.16 +github.com/go-logr/logr +github.com/go-logr/logr/funcr +# github.com/go-logr/stdr v1.2.2 +## explicit; go 1.16 +github.com/go-logr/stdr +# github.com/go-playground/locales v0.14.1 +## explicit; go 1.17 +github.com/go-playground/locales +github.com/go-playground/locales/currency +# github.com/go-playground/universal-translator v0.18.1 +## explicit; go 1.18 +github.com/go-playground/universal-translator +# github.com/go-playground/validator/v10 v10.15.5 +## explicit; go 1.18 +github.com/go-playground/validator/v10 +# github.com/go-redis/redis/v8 v8.11.5 +## explicit; go 1.17 +github.com/go-redis/redis/v8 +github.com/go-redis/redis/v8/internal +github.com/go-redis/redis/v8/internal/hashtag +github.com/go-redis/redis/v8/internal/hscan +github.com/go-redis/redis/v8/internal/pool +github.com/go-redis/redis/v8/internal/proto +github.com/go-redis/redis/v8/internal/rand +github.com/go-redis/redis/v8/internal/util +# github.com/go-zoo/bone v1.3.0 +## explicit; go 1.9 +github.com/go-zoo/bone +# github.com/goccy/go-json v0.10.2 +## explicit; go 1.12 +github.com/goccy/go-json +github.com/goccy/go-json/internal/decoder +github.com/goccy/go-json/internal/encoder +github.com/goccy/go-json/internal/encoder/vm +github.com/goccy/go-json/internal/encoder/vm_color +github.com/goccy/go-json/internal/encoder/vm_color_indent 
+github.com/goccy/go-json/internal/encoder/vm_indent +github.com/goccy/go-json/internal/errors +github.com/goccy/go-json/internal/runtime +# github.com/gocql/gocql v1.6.0 +## explicit; go 1.13 +github.com/gocql/gocql +github.com/gocql/gocql/internal/lru +github.com/gocql/gocql/internal/murmur +github.com/gocql/gocql/internal/streams +# github.com/gofrs/uuid v4.4.0+incompatible +## explicit +github.com/gofrs/uuid +# github.com/gogo/protobuf v1.3.2 +## explicit; go 1.15 +github.com/gogo/protobuf/proto +# github.com/golang/protobuf v1.5.3 +## explicit; go 1.9 +github.com/golang/protobuf/jsonpb +github.com/golang/protobuf/proto +github.com/golang/protobuf/ptypes +github.com/golang/protobuf/ptypes/any +github.com/golang/protobuf/ptypes/duration +github.com/golang/protobuf/ptypes/timestamp +# github.com/golang/snappy v0.0.4 +## explicit +github.com/golang/snappy +# github.com/gomarkdown/markdown v0.0.0-20230922112808-5421fefb8386 +## explicit; go 1.12 +github.com/gomarkdown/markdown +github.com/gomarkdown/markdown/ast +github.com/gomarkdown/markdown/html +github.com/gomarkdown/markdown/parser +# github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 +## explicit; go 1.13 +github.com/google/shlex +# github.com/google/uuid v1.3.1 +## explicit +github.com/google/uuid +# github.com/gookit/color v1.5.4 +## explicit; go 1.18 +github.com/gookit/color +# github.com/gopcua/opcua v0.1.6 +## explicit; go 1.12 +github.com/gopcua/opcua +github.com/gopcua/opcua/debug +github.com/gopcua/opcua/errors +github.com/gopcua/opcua/id +github.com/gopcua/opcua/ua +github.com/gopcua/opcua/uacp +github.com/gopcua/opcua/uapolicy +github.com/gopcua/opcua/uasc +# github.com/gorilla/css v1.0.0 +## explicit +github.com/gorilla/css/scanner +# github.com/gorilla/websocket v1.5.0 +## explicit; go 1.12 +github.com/gorilla/websocket +# github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 +## explicit; go 1.14 +github.com/grpc-ecosystem/go-grpc-middleware +github.com/grpc-ecosystem/go-grpc-middleware/auth +github.com/grpc-ecosystem/go-grpc-middleware/util/metautils +github.com/grpc-ecosystem/go-grpc-middleware/validator +# github.com/grpc-ecosystem/grpc-gateway/v2 v2.18.0 +## explicit; go 1.17 +github.com/grpc-ecosystem/grpc-gateway/v2/internal/httprule +github.com/grpc-ecosystem/grpc-gateway/v2/protoc-gen-openapiv2/options +github.com/grpc-ecosystem/grpc-gateway/v2/runtime +github.com/grpc-ecosystem/grpc-gateway/v2/utilities +# github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed +## explicit +github.com/hailocab/go-hostpool +# github.com/hashicorp/errwrap v1.1.0 +## explicit +github.com/hashicorp/errwrap +# github.com/hashicorp/go-cleanhttp v0.5.2 +## explicit; go 1.13 +github.com/hashicorp/go-cleanhttp +# github.com/hashicorp/go-multierror v1.1.1 +## explicit; go 1.13 +github.com/hashicorp/go-multierror +# github.com/hashicorp/go-retryablehttp v0.7.4 +## explicit; go 1.13 +github.com/hashicorp/go-retryablehttp +# github.com/hashicorp/go-rootcerts v1.0.2 +## explicit; go 1.12 +github.com/hashicorp/go-rootcerts +# github.com/hashicorp/go-secure-stdlib/parseutil v0.1.7 +## explicit; go 1.16 +github.com/hashicorp/go-secure-stdlib/parseutil +# github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 +## explicit; go 1.16 +github.com/hashicorp/go-secure-stdlib/strutil +# github.com/hashicorp/go-sockaddr v1.0.5 +## explicit; go 1.19 +github.com/hashicorp/go-sockaddr +# github.com/hashicorp/hcl v1.0.0 +## explicit +github.com/hashicorp/hcl +github.com/hashicorp/hcl/hcl/ast +github.com/hashicorp/hcl/hcl/parser 
+github.com/hashicorp/hcl/hcl/printer +github.com/hashicorp/hcl/hcl/scanner +github.com/hashicorp/hcl/hcl/strconv +github.com/hashicorp/hcl/hcl/token +github.com/hashicorp/hcl/json/parser +github.com/hashicorp/hcl/json/scanner +github.com/hashicorp/hcl/json/token +# github.com/hashicorp/vault/api v1.10.0 +## explicit; go 1.19 +github.com/hashicorp/vault/api +# github.com/hokaccha/go-prettyjson v0.0.0-20211117102719-0474bc63780f +## explicit +github.com/hokaccha/go-prettyjson +# github.com/imdario/mergo v0.3.16 +## explicit; go 1.13 +github.com/imdario/mergo +# github.com/inconshreveable/mousetrap v1.1.0 +## explicit; go 1.18 +github.com/inconshreveable/mousetrap +# github.com/influxdata/influxdb-client-go/v2 v2.12.3 +## explicit; go 1.17 +github.com/influxdata/influxdb-client-go/v2 +github.com/influxdata/influxdb-client-go/v2/api +github.com/influxdata/influxdb-client-go/v2/api/http +github.com/influxdata/influxdb-client-go/v2/api/query +github.com/influxdata/influxdb-client-go/v2/api/write +github.com/influxdata/influxdb-client-go/v2/domain +github.com/influxdata/influxdb-client-go/v2/internal/gzip +github.com/influxdata/influxdb-client-go/v2/internal/http +github.com/influxdata/influxdb-client-go/v2/internal/log +github.com/influxdata/influxdb-client-go/v2/internal/write +github.com/influxdata/influxdb-client-go/v2/log +# github.com/influxdata/line-protocol v0.0.0-20210922203350-b1ad95c89adf +## explicit; go 1.13 +github.com/influxdata/line-protocol +# github.com/iris-contrib/schema v0.0.6 +## explicit; go 1.14 +github.com/iris-contrib/schema +# github.com/ivanpirog/coloredcobra v1.0.1 +## explicit; go 1.15 +github.com/ivanpirog/coloredcobra +# github.com/jackc/pgerrcode v0.0.0-20220416144525-469b46aa5efa +## explicit; go 1.12 +github.com/jackc/pgerrcode +# github.com/jackc/pgio v1.0.0 +## explicit; go 1.12 +github.com/jackc/pgio +# github.com/jackc/pgpassfile v1.0.0 +## explicit; go 1.12 +github.com/jackc/pgpassfile +# github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a +## explicit; go 1.14 +github.com/jackc/pgservicefile +# github.com/jackc/pgtype v1.14.0 +## explicit; go 1.13 +github.com/jackc/pgtype +# github.com/jackc/pgx/v5 v5.4.3 +## explicit; go 1.19 +github.com/jackc/pgx/v5 +github.com/jackc/pgx/v5/internal/anynil +github.com/jackc/pgx/v5/internal/iobufpool +github.com/jackc/pgx/v5/internal/pgio +github.com/jackc/pgx/v5/internal/sanitize +github.com/jackc/pgx/v5/internal/stmtcache +github.com/jackc/pgx/v5/pgconn +github.com/jackc/pgx/v5/pgconn/internal/bgreader +github.com/jackc/pgx/v5/pgconn/internal/ctxwatch +github.com/jackc/pgx/v5/pgproto3 +github.com/jackc/pgx/v5/pgtype +github.com/jackc/pgx/v5/stdlib +# github.com/jmoiron/sqlx v1.3.5 +## explicit; go 1.10 +github.com/jmoiron/sqlx +github.com/jmoiron/sqlx/reflectx +# github.com/josharian/intern v1.0.0 +## explicit; go 1.5 +github.com/josharian/intern +# github.com/json-iterator/go v1.1.12 +## explicit; go 1.12 +github.com/json-iterator/go +# github.com/jzelinskie/stringz v0.0.2 +## explicit; go 1.12 +github.com/jzelinskie/stringz +# github.com/kataras/blocks v0.0.8 +## explicit; go 1.20 +github.com/kataras/blocks +# github.com/kataras/golog v0.1.9 +## explicit; go 1.20 +github.com/kataras/golog +# github.com/kataras/iris/v12 v12.2.7 +## explicit; go 1.20 +github.com/kataras/iris/v12 +github.com/kataras/iris/v12/cache +github.com/kataras/iris/v12/cache/cfg +github.com/kataras/iris/v12/cache/client +github.com/kataras/iris/v12/cache/client/rule +github.com/kataras/iris/v12/cache/entry 
+github.com/kataras/iris/v12/cache/ruleset +github.com/kataras/iris/v12/cache/uri +github.com/kataras/iris/v12/context +github.com/kataras/iris/v12/core/errgroup +github.com/kataras/iris/v12/core/handlerconv +github.com/kataras/iris/v12/core/host +github.com/kataras/iris/v12/core/memstore +github.com/kataras/iris/v12/core/netutil +github.com/kataras/iris/v12/core/router +github.com/kataras/iris/v12/hero +github.com/kataras/iris/v12/i18n +github.com/kataras/iris/v12/i18n/internal +github.com/kataras/iris/v12/macro +github.com/kataras/iris/v12/macro/handler +github.com/kataras/iris/v12/macro/interpreter/ast +github.com/kataras/iris/v12/macro/interpreter/lexer +github.com/kataras/iris/v12/macro/interpreter/parser +github.com/kataras/iris/v12/macro/interpreter/token +github.com/kataras/iris/v12/middleware/cors +github.com/kataras/iris/v12/middleware/modrevision +github.com/kataras/iris/v12/middleware/recover +github.com/kataras/iris/v12/middleware/requestid +github.com/kataras/iris/v12/sessions +github.com/kataras/iris/v12/view +github.com/kataras/iris/v12/x/client +github.com/kataras/iris/v12/x/errors +# github.com/kataras/pio v0.0.12 +## explicit; go 1.20 +github.com/kataras/pio +github.com/kataras/pio/terminal +# github.com/kataras/sitemap v0.0.6 +## explicit; go 1.19 +github.com/kataras/sitemap +# github.com/kataras/tunnel v0.0.4 +## explicit; go 1.18 +github.com/kataras/tunnel +# github.com/klauspost/compress v1.17.0 +## explicit; go 1.18 +github.com/klauspost/compress +github.com/klauspost/compress/flate +github.com/klauspost/compress/fse +github.com/klauspost/compress/gzip +github.com/klauspost/compress/huff0 +github.com/klauspost/compress/internal/cpuinfo +github.com/klauspost/compress/internal/snapref +github.com/klauspost/compress/s2 +github.com/klauspost/compress/zstd +github.com/klauspost/compress/zstd/internal/xxhash +# github.com/klauspost/cpuid/v2 v2.2.5 +## explicit; go 1.15 +github.com/klauspost/cpuid/v2 +# github.com/labstack/echo/v4 v4.11.2 +## explicit; go 1.17 +github.com/labstack/echo/v4 +# github.com/labstack/gommon v0.4.0 +## explicit; go 1.12 +github.com/labstack/gommon/color +github.com/labstack/gommon/log +# github.com/leodido/go-urn v1.2.4 +## explicit; go 1.16 +github.com/leodido/go-urn +# github.com/lestrrat-go/blackmagic v1.0.2 +## explicit; go 1.16 +github.com/lestrrat-go/blackmagic +# github.com/lestrrat-go/httpcc v1.0.1 +## explicit; go 1.16 +github.com/lestrrat-go/httpcc +# github.com/lestrrat-go/httprc v1.0.4 +## explicit; go 1.17 +github.com/lestrrat-go/httprc +# github.com/lestrrat-go/iter v1.0.2 +## explicit; go 1.13 +github.com/lestrrat-go/iter/arrayiter +github.com/lestrrat-go/iter/mapiter +# github.com/lestrrat-go/jwx/v2 v2.0.13 +## explicit; go 1.16 +github.com/lestrrat-go/jwx/v2 +github.com/lestrrat-go/jwx/v2/cert +github.com/lestrrat-go/jwx/v2/internal/base64 +github.com/lestrrat-go/jwx/v2/internal/ecutil +github.com/lestrrat-go/jwx/v2/internal/iter +github.com/lestrrat-go/jwx/v2/internal/json +github.com/lestrrat-go/jwx/v2/internal/keyconv +github.com/lestrrat-go/jwx/v2/internal/pool +github.com/lestrrat-go/jwx/v2/jwa +github.com/lestrrat-go/jwx/v2/jwe +github.com/lestrrat-go/jwx/v2/jwe/internal/aescbc +github.com/lestrrat-go/jwx/v2/jwe/internal/cipher +github.com/lestrrat-go/jwx/v2/jwe/internal/concatkdf +github.com/lestrrat-go/jwx/v2/jwe/internal/content_crypt +github.com/lestrrat-go/jwx/v2/jwe/internal/keyenc +github.com/lestrrat-go/jwx/v2/jwe/internal/keygen +github.com/lestrrat-go/jwx/v2/jwk +github.com/lestrrat-go/jwx/v2/jws 
+github.com/lestrrat-go/jwx/v2/jwt +github.com/lestrrat-go/jwx/v2/jwt/internal/types +github.com/lestrrat-go/jwx/v2/x25519 +# github.com/lestrrat-go/option v1.0.1 +## explicit; go 1.16 +github.com/lestrrat-go/option +# github.com/magiconair/properties v1.8.7 +## explicit; go 1.19 +github.com/magiconair/properties +# github.com/mailgun/raymond/v2 v2.0.48 +## explicit; go 1.16 +github.com/mailgun/raymond/v2 +github.com/mailgun/raymond/v2/ast +github.com/mailgun/raymond/v2/lexer +github.com/mailgun/raymond/v2/parser +# github.com/mailru/easyjson v0.7.7 +## explicit; go 1.12 +github.com/mailru/easyjson +github.com/mailru/easyjson/buffer +github.com/mailru/easyjson/jlexer +github.com/mailru/easyjson/jwriter +# github.com/mainflux/callhome v0.0.0-20230920140432-33c5663382ce +## explicit; go 1.21 +github.com/mainflux/callhome/pkg/client +# github.com/mainflux/mproxy v0.3.1-0.20231022160500-0e0db9e1642c +## explicit; go 1.19 +github.com/mainflux/mproxy/pkg/http +github.com/mainflux/mproxy/pkg/logger +github.com/mainflux/mproxy/pkg/mqtt +github.com/mainflux/mproxy/pkg/mqtt/websocket +github.com/mainflux/mproxy/pkg/session +github.com/mainflux/mproxy/pkg/tls +github.com/mainflux/mproxy/pkg/websockets +# github.com/mainflux/senml v1.5.0 +## explicit; go 1.13 +github.com/mainflux/senml +# github.com/mattn/go-colorable v0.1.13 +## explicit; go 1.15 +github.com/mattn/go-colorable +# github.com/mattn/go-isatty v0.0.19 +## explicit; go 1.15 +github.com/mattn/go-isatty +# github.com/matttproud/golang_protobuf_extensions v1.0.4 +## explicit; go 1.9 +github.com/matttproud/golang_protobuf_extensions/pbutil +# github.com/microcosm-cc/bluemonday v1.0.26 +## explicit; go 1.21 +github.com/microcosm-cc/bluemonday +github.com/microcosm-cc/bluemonday/css +# github.com/mitchellh/go-homedir v1.1.0 +## explicit +github.com/mitchellh/go-homedir +# github.com/mitchellh/mapstructure v1.5.0 +## explicit; go 1.14 +github.com/mitchellh/mapstructure +# github.com/moby/term v0.5.0 +## explicit; go 1.18 +github.com/moby/term +github.com/moby/term/windows +# github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd +## explicit +github.com/modern-go/concurrent +# github.com/modern-go/reflect2 v1.0.2 +## explicit; go 1.12 +github.com/modern-go/reflect2 +# github.com/montanaflynn/stats v0.7.1 +## explicit; go 1.13 +github.com/montanaflynn/stats +# github.com/nats-io/nats.go v1.30.2 +## explicit; go 1.20 +github.com/nats-io/nats.go +github.com/nats-io/nats.go/encoders/builtin +github.com/nats-io/nats.go/internal/parser +github.com/nats-io/nats.go/jetstream +github.com/nats-io/nats.go/util +# github.com/nats-io/nkeys v0.4.5 +## explicit; go 1.19 +github.com/nats-io/nkeys +# github.com/nats-io/nuid v1.0.1 +## explicit +github.com/nats-io/nuid +# github.com/oklog/ulid/v2 v2.1.0 +## explicit; go 1.15 +github.com/oklog/ulid/v2 +# github.com/opencontainers/go-digest v1.0.0 +## explicit; go 1.13 +github.com/opencontainers/go-digest +# github.com/opencontainers/image-spec v1.0.2 +## explicit +github.com/opencontainers/image-spec/specs-go +github.com/opencontainers/image-spec/specs-go/v1 +# github.com/opencontainers/runc v1.1.7 +## explicit; go 1.17 +github.com/opencontainers/runc/libcontainer/user +# github.com/ory/dockertest/v3 v3.10.0 +## explicit; go 1.17 +github.com/ory/dockertest/v3 +github.com/ory/dockertest/v3/docker +github.com/ory/dockertest/v3/docker/opts +github.com/ory/dockertest/v3/docker/pkg/archive +github.com/ory/dockertest/v3/docker/pkg/fileutils +github.com/ory/dockertest/v3/docker/pkg/homedir 
+github.com/ory/dockertest/v3/docker/pkg/idtools +github.com/ory/dockertest/v3/docker/pkg/ioutils +github.com/ory/dockertest/v3/docker/pkg/jsonmessage +github.com/ory/dockertest/v3/docker/pkg/longpath +github.com/ory/dockertest/v3/docker/pkg/mount +github.com/ory/dockertest/v3/docker/pkg/pools +github.com/ory/dockertest/v3/docker/pkg/stdcopy +github.com/ory/dockertest/v3/docker/pkg/system +github.com/ory/dockertest/v3/docker/types +github.com/ory/dockertest/v3/docker/types/blkiodev +github.com/ory/dockertest/v3/docker/types/container +github.com/ory/dockertest/v3/docker/types/filters +github.com/ory/dockertest/v3/docker/types/mount +github.com/ory/dockertest/v3/docker/types/network +github.com/ory/dockertest/v3/docker/types/registry +github.com/ory/dockertest/v3/docker/types/strslice +github.com/ory/dockertest/v3/docker/types/versions +# github.com/pelletier/go-toml v1.9.5 +## explicit; go 1.12 +github.com/pelletier/go-toml +# github.com/pelletier/go-toml/v2 v2.1.0 +## explicit; go 1.16 +github.com/pelletier/go-toml/v2 +github.com/pelletier/go-toml/v2/internal/characters +github.com/pelletier/go-toml/v2/internal/danger +github.com/pelletier/go-toml/v2/internal/tracker +github.com/pelletier/go-toml/v2/unstable +# github.com/pion/dtls/v2 v2.2.8-0.20230905141523-2b584af66577 +## explicit; go 1.13 +github.com/pion/dtls/v2 +github.com/pion/dtls/v2/internal/ciphersuite +github.com/pion/dtls/v2/internal/ciphersuite/types +github.com/pion/dtls/v2/internal/closer +github.com/pion/dtls/v2/internal/net +github.com/pion/dtls/v2/internal/net/udp +github.com/pion/dtls/v2/internal/util +github.com/pion/dtls/v2/pkg/crypto/ccm +github.com/pion/dtls/v2/pkg/crypto/ciphersuite +github.com/pion/dtls/v2/pkg/crypto/clientcertificate +github.com/pion/dtls/v2/pkg/crypto/elliptic +github.com/pion/dtls/v2/pkg/crypto/hash +github.com/pion/dtls/v2/pkg/crypto/prf +github.com/pion/dtls/v2/pkg/crypto/signature +github.com/pion/dtls/v2/pkg/crypto/signaturehash +github.com/pion/dtls/v2/pkg/net +github.com/pion/dtls/v2/pkg/protocol +github.com/pion/dtls/v2/pkg/protocol/alert +github.com/pion/dtls/v2/pkg/protocol/extension +github.com/pion/dtls/v2/pkg/protocol/handshake +github.com/pion/dtls/v2/pkg/protocol/recordlayer +# github.com/pion/logging v0.2.2 +## explicit; go 1.12 +github.com/pion/logging +# github.com/pion/transport/v3 v3.0.1 +## explicit; go 1.12 +github.com/pion/transport/v3/deadline +github.com/pion/transport/v3/netctx +github.com/pion/transport/v3/packetio +github.com/pion/transport/v3/replaydetector +github.com/pion/transport/v3/udp +# github.com/pkg/errors v0.9.1 +## explicit +github.com/pkg/errors +# github.com/plgd-dev/go-coap/v3 v3.1.5 +## explicit; go 1.18 +github.com/plgd-dev/go-coap/v3 +github.com/plgd-dev/go-coap/v3/dtls +github.com/plgd-dev/go-coap/v3/dtls/server +github.com/plgd-dev/go-coap/v3/message +github.com/plgd-dev/go-coap/v3/message/codes +github.com/plgd-dev/go-coap/v3/message/noresponse +github.com/plgd-dev/go-coap/v3/message/pool +github.com/plgd-dev/go-coap/v3/mux +github.com/plgd-dev/go-coap/v3/net +github.com/plgd-dev/go-coap/v3/net/blockwise +github.com/plgd-dev/go-coap/v3/net/client +github.com/plgd-dev/go-coap/v3/net/client/limitParallelRequests +github.com/plgd-dev/go-coap/v3/net/monitor/inactivity +github.com/plgd-dev/go-coap/v3/net/observation +github.com/plgd-dev/go-coap/v3/net/responsewriter +github.com/plgd-dev/go-coap/v3/options +github.com/plgd-dev/go-coap/v3/options/config +github.com/plgd-dev/go-coap/v3/pkg/cache +github.com/plgd-dev/go-coap/v3/pkg/connections 
+github.com/plgd-dev/go-coap/v3/pkg/errors +github.com/plgd-dev/go-coap/v3/pkg/fn +github.com/plgd-dev/go-coap/v3/pkg/rand +github.com/plgd-dev/go-coap/v3/pkg/runner/periodic +github.com/plgd-dev/go-coap/v3/pkg/sync +github.com/plgd-dev/go-coap/v3/tcp +github.com/plgd-dev/go-coap/v3/tcp/client +github.com/plgd-dev/go-coap/v3/tcp/coder +github.com/plgd-dev/go-coap/v3/tcp/server +github.com/plgd-dev/go-coap/v3/udp +github.com/plgd-dev/go-coap/v3/udp/client +github.com/plgd-dev/go-coap/v3/udp/coder +github.com/plgd-dev/go-coap/v3/udp/server +# github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 +## explicit +github.com/pmezard/go-difflib/difflib +# github.com/prometheus/client_golang v1.16.0 +## explicit; go 1.17 +github.com/prometheus/client_golang/prometheus +github.com/prometheus/client_golang/prometheus/internal +github.com/prometheus/client_golang/prometheus/promhttp +# github.com/prometheus/client_model v0.4.0 +## explicit; go 1.18 +github.com/prometheus/client_model/go +# github.com/prometheus/common v0.44.0 +## explicit; go 1.18 +github.com/prometheus/common/expfmt +github.com/prometheus/common/internal/bitbucket.org/ww/goautoneg +github.com/prometheus/common/model +# github.com/prometheus/procfs v0.12.0 +## explicit; go 1.19 +github.com/prometheus/procfs +github.com/prometheus/procfs/internal/fs +github.com/prometheus/procfs/internal/util +# github.com/rabbitmq/amqp091-go v1.8.1 +## explicit; go 1.16 +github.com/rabbitmq/amqp091-go +# github.com/rubenv/sql-migrate v1.5.1 +## explicit; go 1.17 +github.com/rubenv/sql-migrate +github.com/rubenv/sql-migrate/sqlparse +# github.com/russross/blackfriday/v2 v2.1.0 +## explicit +github.com/russross/blackfriday/v2 +# github.com/ryanuber/go-glob v1.0.0 +## explicit +github.com/ryanuber/go-glob +# github.com/sagikazarmark/locafero v0.3.0 +## explicit; go 1.20 +github.com/sagikazarmark/locafero +# github.com/sagikazarmark/slog-shim v0.1.0 +## explicit; go 1.20 +github.com/sagikazarmark/slog-shim +# github.com/schollz/closestmatch v2.1.0+incompatible +## explicit +github.com/schollz/closestmatch +# github.com/segmentio/asm v1.2.0 +## explicit; go 1.18 +github.com/segmentio/asm/base64 +github.com/segmentio/asm/cpu +github.com/segmentio/asm/cpu/arm +github.com/segmentio/asm/cpu/arm64 +github.com/segmentio/asm/cpu/cpuid +github.com/segmentio/asm/cpu/x86 +github.com/segmentio/asm/internal/unsafebytes +# github.com/sirupsen/logrus v1.9.3 +## explicit; go 1.13 +github.com/sirupsen/logrus +# github.com/sourcegraph/conc v0.3.0 +## explicit; go 1.19 +github.com/sourcegraph/conc +github.com/sourcegraph/conc/internal/multierror +github.com/sourcegraph/conc/iter +github.com/sourcegraph/conc/panics +# github.com/spf13/afero v1.10.0 +## explicit; go 1.16 +github.com/spf13/afero +github.com/spf13/afero/internal/common +github.com/spf13/afero/mem +# github.com/spf13/cast v1.5.1 +## explicit; go 1.18 +github.com/spf13/cast +# github.com/spf13/cobra v1.7.0 +## explicit; go 1.15 +github.com/spf13/cobra +# github.com/spf13/pflag v1.0.5 +## explicit; go 1.12 +github.com/spf13/pflag +# github.com/spf13/viper v1.17.0 +## explicit; go 1.18 +github.com/spf13/viper +github.com/spf13/viper/internal/encoding +github.com/spf13/viper/internal/encoding/dotenv +github.com/spf13/viper/internal/encoding/hcl +github.com/spf13/viper/internal/encoding/ini +github.com/spf13/viper/internal/encoding/javaproperties +github.com/spf13/viper/internal/encoding/json +github.com/spf13/viper/internal/encoding/toml +github.com/spf13/viper/internal/encoding/yaml +# 
github.com/stretchr/objx v0.5.1 +## explicit; go 1.13 +github.com/stretchr/objx +# github.com/stretchr/testify v1.8.4 +## explicit; go 1.20 +github.com/stretchr/testify/assert +github.com/stretchr/testify/mock +github.com/stretchr/testify/require +# github.com/subosito/gotenv v1.6.0 +## explicit; go 1.18 +github.com/subosito/gotenv +# github.com/tdewolff/minify/v2 v2.12.9 +## explicit; go 1.18 +github.com/tdewolff/minify/v2 +github.com/tdewolff/minify/v2/css +github.com/tdewolff/minify/v2/html +github.com/tdewolff/minify/v2/js +github.com/tdewolff/minify/v2/json +github.com/tdewolff/minify/v2/svg +github.com/tdewolff/minify/v2/xml +# github.com/tdewolff/parse/v2 v2.6.8 +## explicit; go 1.13 +github.com/tdewolff/parse/v2 +github.com/tdewolff/parse/v2/buffer +github.com/tdewolff/parse/v2/css +github.com/tdewolff/parse/v2/html +github.com/tdewolff/parse/v2/js +github.com/tdewolff/parse/v2/json +github.com/tdewolff/parse/v2/strconv +github.com/tdewolff/parse/v2/xml +# github.com/twitchyliquid64/golang-asm v0.15.1 +## explicit; go 1.13 +github.com/twitchyliquid64/golang-asm/asm/arch +github.com/twitchyliquid64/golang-asm/bio +github.com/twitchyliquid64/golang-asm/dwarf +github.com/twitchyliquid64/golang-asm/goobj +github.com/twitchyliquid64/golang-asm/obj +github.com/twitchyliquid64/golang-asm/obj/arm +github.com/twitchyliquid64/golang-asm/obj/arm64 +github.com/twitchyliquid64/golang-asm/obj/mips +github.com/twitchyliquid64/golang-asm/obj/ppc64 +github.com/twitchyliquid64/golang-asm/obj/riscv +github.com/twitchyliquid64/golang-asm/obj/s390x +github.com/twitchyliquid64/golang-asm/obj/wasm +github.com/twitchyliquid64/golang-asm/obj/x86 +github.com/twitchyliquid64/golang-asm/objabi +github.com/twitchyliquid64/golang-asm/src +github.com/twitchyliquid64/golang-asm/sys +github.com/twitchyliquid64/golang-asm/unsafeheader +# github.com/ugorji/go/codec v1.2.11 +## explicit; go 1.11 +github.com/ugorji/go/codec +# github.com/valyala/bytebufferpool v1.0.0 +## explicit +github.com/valyala/bytebufferpool +# github.com/valyala/fasttemplate v1.2.2 +## explicit; go 1.12 +github.com/valyala/fasttemplate +# github.com/vmihailenco/msgpack/v5 v5.4.0 +## explicit; go 1.19 +github.com/vmihailenco/msgpack/v5 +github.com/vmihailenco/msgpack/v5/msgpcode +# github.com/vmihailenco/tagparser/v2 v2.0.0 +## explicit; go 1.15 +github.com/vmihailenco/tagparser/v2 +github.com/vmihailenco/tagparser/v2/internal +github.com/vmihailenco/tagparser/v2/internal/parser +# github.com/x448/float16 v0.8.4 +## explicit; go 1.11 +github.com/x448/float16 +# github.com/xdg-go/pbkdf2 v1.0.0 +## explicit; go 1.9 +github.com/xdg-go/pbkdf2 +# github.com/xdg-go/scram v1.1.2 +## explicit; go 1.11 +github.com/xdg-go/scram +# github.com/xdg-go/stringprep v1.0.4 +## explicit; go 1.11 +github.com/xdg-go/stringprep +# github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb +## explicit +github.com/xeipuuv/gojsonpointer +# github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 +## explicit +github.com/xeipuuv/gojsonreference +# github.com/xeipuuv/gojsonschema v1.2.0 +## explicit +github.com/xeipuuv/gojsonschema +# github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e +## explicit; go 1.19 +github.com/xo/terminfo +# github.com/yosssi/ace v0.0.5 +## explicit +github.com/yosssi/ace +# github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a +## explicit; go 1.12 +github.com/youmark/pkcs8 +# go.mongodb.org/mongo-driver v1.12.0 +## explicit; go 1.13 +go.mongodb.org/mongo-driver/bson +go.mongodb.org/mongo-driver/bson/bsoncodec 
+go.mongodb.org/mongo-driver/bson/bsonoptions +go.mongodb.org/mongo-driver/bson/bsonrw +go.mongodb.org/mongo-driver/bson/bsontype +go.mongodb.org/mongo-driver/bson/primitive +go.mongodb.org/mongo-driver/event +go.mongodb.org/mongo-driver/internal +go.mongodb.org/mongo-driver/internal/aws +go.mongodb.org/mongo-driver/internal/aws/awserr +go.mongodb.org/mongo-driver/internal/aws/credentials +go.mongodb.org/mongo-driver/internal/aws/signer/v4 +go.mongodb.org/mongo-driver/internal/credproviders +go.mongodb.org/mongo-driver/internal/logger +go.mongodb.org/mongo-driver/internal/randutil +go.mongodb.org/mongo-driver/internal/randutil/rand +go.mongodb.org/mongo-driver/internal/uuid +go.mongodb.org/mongo-driver/mongo +go.mongodb.org/mongo-driver/mongo/address +go.mongodb.org/mongo-driver/mongo/description +go.mongodb.org/mongo-driver/mongo/options +go.mongodb.org/mongo-driver/mongo/readconcern +go.mongodb.org/mongo-driver/mongo/readpref +go.mongodb.org/mongo-driver/mongo/writeconcern +go.mongodb.org/mongo-driver/tag +go.mongodb.org/mongo-driver/version +go.mongodb.org/mongo-driver/x/bsonx/bsoncore +go.mongodb.org/mongo-driver/x/mongo/driver +go.mongodb.org/mongo-driver/x/mongo/driver/auth +go.mongodb.org/mongo-driver/x/mongo/driver/auth/creds +go.mongodb.org/mongo-driver/x/mongo/driver/auth/internal/gssapi +go.mongodb.org/mongo-driver/x/mongo/driver/connstring +go.mongodb.org/mongo-driver/x/mongo/driver/dns +go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt +go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt/options +go.mongodb.org/mongo-driver/x/mongo/driver/ocsp +go.mongodb.org/mongo-driver/x/mongo/driver/operation +go.mongodb.org/mongo-driver/x/mongo/driver/session +go.mongodb.org/mongo-driver/x/mongo/driver/topology +go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage +# go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.42.0 +## explicit; go 1.19 +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc/internal +# go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.42.0 +## explicit; go 1.19 +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp +# go.opentelemetry.io/otel v1.19.0 +## explicit; go 1.20 +go.opentelemetry.io/otel +go.opentelemetry.io/otel/attribute +go.opentelemetry.io/otel/baggage +go.opentelemetry.io/otel/codes +go.opentelemetry.io/otel/internal +go.opentelemetry.io/otel/internal/attribute +go.opentelemetry.io/otel/internal/baggage +go.opentelemetry.io/otel/internal/global +go.opentelemetry.io/otel/propagation +go.opentelemetry.io/otel/semconv/internal +go.opentelemetry.io/otel/semconv/internal/v2 +go.opentelemetry.io/otel/semconv/v1.12.0 +go.opentelemetry.io/otel/semconv/v1.17.0 +go.opentelemetry.io/otel/semconv/v1.17.0/httpconv +# go.opentelemetry.io/otel/exporters/jaeger v1.16.0 +## explicit; go 1.19 +go.opentelemetry.io/otel/exporters/jaeger +go.opentelemetry.io/otel/exporters/jaeger/internal/gen-go/agent +go.opentelemetry.io/otel/exporters/jaeger/internal/gen-go/jaeger +go.opentelemetry.io/otel/exporters/jaeger/internal/gen-go/zipkincore +go.opentelemetry.io/otel/exporters/jaeger/internal/third_party/thrift/lib/go/thrift +# go.opentelemetry.io/otel/metric v1.19.0 +## explicit; go 1.20 +go.opentelemetry.io/otel/metric +go.opentelemetry.io/otel/metric/embedded +# go.opentelemetry.io/otel/sdk v1.16.0 +## explicit; go 1.19 +go.opentelemetry.io/otel/sdk +go.opentelemetry.io/otel/sdk/instrumentation 
+go.opentelemetry.io/otel/sdk/internal +go.opentelemetry.io/otel/sdk/internal/env +go.opentelemetry.io/otel/sdk/resource +go.opentelemetry.io/otel/sdk/trace +# go.opentelemetry.io/otel/trace v1.19.0 +## explicit; go 1.20 +go.opentelemetry.io/otel/trace +# go.uber.org/atomic v1.11.0 +## explicit; go 1.18 +go.uber.org/atomic +# go.uber.org/multierr v1.9.0 +## explicit; go 1.19 +go.uber.org/multierr +# golang.org/x/arch v0.4.0 +## explicit; go 1.17 +golang.org/x/arch/x86/x86asm +# golang.org/x/crypto v0.14.0 +## explicit; go 1.17 +golang.org/x/crypto/acme +golang.org/x/crypto/acme/autocert +golang.org/x/crypto/bcrypt +golang.org/x/crypto/blake2b +golang.org/x/crypto/blowfish +golang.org/x/crypto/cryptobyte +golang.org/x/crypto/cryptobyte/asn1 +golang.org/x/crypto/curve25519 +golang.org/x/crypto/curve25519/internal/field +golang.org/x/crypto/ed25519 +golang.org/x/crypto/internal/alias +golang.org/x/crypto/internal/poly1305 +golang.org/x/crypto/nacl/box +golang.org/x/crypto/nacl/secretbox +golang.org/x/crypto/ocsp +golang.org/x/crypto/pbkdf2 +golang.org/x/crypto/salsa20/salsa +golang.org/x/crypto/scrypt +golang.org/x/crypto/sha3 +# golang.org/x/exp v0.0.0-20230905200255-921286631fa9 +## explicit; go 1.20 +golang.org/x/exp/constraints +golang.org/x/exp/maps +golang.org/x/exp/slices +golang.org/x/exp/slog +golang.org/x/exp/slog/internal +golang.org/x/exp/slog/internal/buffer +# golang.org/x/mod v0.12.0 +## explicit; go 1.17 +golang.org/x/mod/semver +# golang.org/x/net v0.17.0 +## explicit; go 1.17 +golang.org/x/net/bpf +golang.org/x/net/html +golang.org/x/net/html/atom +golang.org/x/net/http/httpguts +golang.org/x/net/http2 +golang.org/x/net/http2/h2c +golang.org/x/net/http2/hpack +golang.org/x/net/idna +golang.org/x/net/internal/iana +golang.org/x/net/internal/socket +golang.org/x/net/internal/socks +golang.org/x/net/internal/timeseries +golang.org/x/net/ipv4 +golang.org/x/net/ipv6 +golang.org/x/net/proxy +golang.org/x/net/publicsuffix +golang.org/x/net/trace +# golang.org/x/sync v0.3.0 +## explicit; go 1.17 +golang.org/x/sync/errgroup +golang.org/x/sync/semaphore +golang.org/x/sync/singleflight +# golang.org/x/sys v0.13.0 +## explicit; go 1.17 +golang.org/x/sys/cpu +golang.org/x/sys/execabs +golang.org/x/sys/unix +golang.org/x/sys/windows +golang.org/x/sys/windows/registry +# golang.org/x/text v0.13.0 +## explicit; go 1.17 +golang.org/x/text/cases +golang.org/x/text/encoding +golang.org/x/text/encoding/charmap +golang.org/x/text/encoding/internal +golang.org/x/text/encoding/internal/identifier +golang.org/x/text/encoding/unicode +golang.org/x/text/feature/plural +golang.org/x/text/internal +golang.org/x/text/internal/catmsg +golang.org/x/text/internal/format +golang.org/x/text/internal/language +golang.org/x/text/internal/language/compact +golang.org/x/text/internal/number +golang.org/x/text/internal/stringset +golang.org/x/text/internal/tag +golang.org/x/text/internal/utf8internal +golang.org/x/text/language +golang.org/x/text/message +golang.org/x/text/message/catalog +golang.org/x/text/runes +golang.org/x/text/secure/bidirule +golang.org/x/text/secure/precis +golang.org/x/text/transform +golang.org/x/text/unicode/bidi +golang.org/x/text/unicode/norm +golang.org/x/text/width +# golang.org/x/time v0.3.0 +## explicit +golang.org/x/time/rate +# golang.org/x/tools v0.13.0 +## explicit; go 1.18 +golang.org/x/tools/cmd/stringer +golang.org/x/tools/go/gcexportdata +golang.org/x/tools/go/internal/packagesdriver +golang.org/x/tools/go/packages +golang.org/x/tools/go/types/objectpath 
+golang.org/x/tools/internal/event +golang.org/x/tools/internal/event/core +golang.org/x/tools/internal/event/keys +golang.org/x/tools/internal/event/label +golang.org/x/tools/internal/event/tag +golang.org/x/tools/internal/gcimporter +golang.org/x/tools/internal/gocommand +golang.org/x/tools/internal/packagesinternal +golang.org/x/tools/internal/pkgbits +golang.org/x/tools/internal/tokeninternal +golang.org/x/tools/internal/typeparams +golang.org/x/tools/internal/typesinternal +# gonum.org/v1/gonum v0.13.0 +## explicit; go 1.18 +gonum.org/v1/gonum/blas +gonum.org/v1/gonum/blas/blas64 +gonum.org/v1/gonum/blas/cblas128 +gonum.org/v1/gonum/blas/gonum +gonum.org/v1/gonum/floats +gonum.org/v1/gonum/floats/scalar +gonum.org/v1/gonum/internal/asm/c128 +gonum.org/v1/gonum/internal/asm/c64 +gonum.org/v1/gonum/internal/asm/f32 +gonum.org/v1/gonum/internal/asm/f64 +gonum.org/v1/gonum/internal/cmplx64 +gonum.org/v1/gonum/internal/math32 +gonum.org/v1/gonum/lapack +gonum.org/v1/gonum/lapack/gonum +gonum.org/v1/gonum/lapack/lapack64 +gonum.org/v1/gonum/mat +gonum.org/v1/gonum/stat +# google.golang.org/genproto v0.0.0-20230913181813-007df8e322eb +## explicit; go 1.19 +google.golang.org/genproto/internal +# google.golang.org/genproto/googleapis/api v0.0.0-20230913181813-007df8e322eb +## explicit; go 1.19 +google.golang.org/genproto/googleapis/api +google.golang.org/genproto/googleapis/api/annotations +google.golang.org/genproto/googleapis/api/httpbody +# google.golang.org/genproto/googleapis/rpc v0.0.0-20230920204549-e6e6cdab5c13 +## explicit; go 1.19 +google.golang.org/genproto/googleapis/rpc/status +# google.golang.org/grpc v1.58.2 +## explicit; go 1.19 +google.golang.org/grpc +google.golang.org/grpc/attributes +google.golang.org/grpc/backoff +google.golang.org/grpc/balancer +google.golang.org/grpc/balancer/base +google.golang.org/grpc/balancer/grpclb/state +google.golang.org/grpc/balancer/roundrobin +google.golang.org/grpc/binarylog/grpc_binarylog_v1 +google.golang.org/grpc/channelz +google.golang.org/grpc/codes +google.golang.org/grpc/connectivity +google.golang.org/grpc/credentials +google.golang.org/grpc/credentials/insecure +google.golang.org/grpc/encoding +google.golang.org/grpc/encoding/proto +google.golang.org/grpc/grpclog +google.golang.org/grpc/health +google.golang.org/grpc/health/grpc_health_v1 +google.golang.org/grpc/internal +google.golang.org/grpc/internal/backoff +google.golang.org/grpc/internal/balancer/gracefulswitch +google.golang.org/grpc/internal/balancerload +google.golang.org/grpc/internal/binarylog +google.golang.org/grpc/internal/buffer +google.golang.org/grpc/internal/channelz +google.golang.org/grpc/internal/credentials +google.golang.org/grpc/internal/envconfig +google.golang.org/grpc/internal/grpclog +google.golang.org/grpc/internal/grpcrand +google.golang.org/grpc/internal/grpcsync +google.golang.org/grpc/internal/grpcutil +google.golang.org/grpc/internal/idle +google.golang.org/grpc/internal/metadata +google.golang.org/grpc/internal/pretty +google.golang.org/grpc/internal/resolver +google.golang.org/grpc/internal/resolver/dns +google.golang.org/grpc/internal/resolver/passthrough +google.golang.org/grpc/internal/resolver/unix +google.golang.org/grpc/internal/serviceconfig +google.golang.org/grpc/internal/status +google.golang.org/grpc/internal/syscall +google.golang.org/grpc/internal/transport +google.golang.org/grpc/internal/transport/networktype +google.golang.org/grpc/keepalive +google.golang.org/grpc/metadata +google.golang.org/grpc/peer 
+google.golang.org/grpc/reflection +google.golang.org/grpc/reflection/grpc_reflection_v1 +google.golang.org/grpc/reflection/grpc_reflection_v1alpha +google.golang.org/grpc/resolver +google.golang.org/grpc/serviceconfig +google.golang.org/grpc/stats +google.golang.org/grpc/status +google.golang.org/grpc/tap +# google.golang.org/protobuf v1.31.0 +## explicit; go 1.11 +google.golang.org/protobuf/encoding/protojson +google.golang.org/protobuf/encoding/prototext +google.golang.org/protobuf/encoding/protowire +google.golang.org/protobuf/internal/descfmt +google.golang.org/protobuf/internal/descopts +google.golang.org/protobuf/internal/detrand +google.golang.org/protobuf/internal/encoding/defval +google.golang.org/protobuf/internal/encoding/json +google.golang.org/protobuf/internal/encoding/messageset +google.golang.org/protobuf/internal/encoding/tag +google.golang.org/protobuf/internal/encoding/text +google.golang.org/protobuf/internal/errors +google.golang.org/protobuf/internal/filedesc +google.golang.org/protobuf/internal/filetype +google.golang.org/protobuf/internal/flags +google.golang.org/protobuf/internal/genid +google.golang.org/protobuf/internal/impl +google.golang.org/protobuf/internal/order +google.golang.org/protobuf/internal/pragma +google.golang.org/protobuf/internal/set +google.golang.org/protobuf/internal/strs +google.golang.org/protobuf/internal/version +google.golang.org/protobuf/proto +google.golang.org/protobuf/reflect/protodesc +google.golang.org/protobuf/reflect/protoreflect +google.golang.org/protobuf/reflect/protoregistry +google.golang.org/protobuf/runtime/protoiface +google.golang.org/protobuf/runtime/protoimpl +google.golang.org/protobuf/types/descriptorpb +google.golang.org/protobuf/types/known/anypb +google.golang.org/protobuf/types/known/durationpb +google.golang.org/protobuf/types/known/fieldmaskpb +google.golang.org/protobuf/types/known/structpb +google.golang.org/protobuf/types/known/timestamppb +google.golang.org/protobuf/types/known/wrapperspb +# gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc +## explicit +gopkg.in/alexcesaro/quotedprintable.v3 +# gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df +## explicit +gopkg.in/gomail.v2 +# gopkg.in/inf.v0 v0.9.1 +## explicit +gopkg.in/inf.v0 +# gopkg.in/ini.v1 v1.67.0 +## explicit +gopkg.in/ini.v1 +# gopkg.in/yaml.v2 v2.4.0 +## explicit; go 1.15 +gopkg.in/yaml.v2 +# gopkg.in/yaml.v3 v3.0.1 +## explicit +gopkg.in/yaml.v3
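The GracefulStop method vendored above fires s.quit, closes all listeners, drains every transport, and then blocks on s.cv until the last connection leaves s.conns. A minimal sketch of driving that shutdown path from a signal handler; the listen address and the use of SIGINT are illustrative assumptions, not part of this change:

package main

import (
	"log"
	"net"
	"os"
	"os/signal"

	"google.golang.org/grpc"
)

func main() {
	lis, err := net.Listen("tcp", ":50051") // hypothetical address
	if err != nil {
		log.Fatal(err)
	}
	srv := grpc.NewServer()
	// Register services on srv here.

	// On SIGINT, stop accepting new connections and RPCs, then block until
	// every in-flight RPC has finished (the s.cv.Wait loop above).
	go func() {
		sig := make(chan os.Signal, 1)
		signal.Notify(sig, os.Interrupt)
		<-sig
		srv.GracefulStop()
	}()

	if err := srv.Serve(lis); err != nil {
		log.Fatal(err)
	}
}

Serve returns nil once GracefulStop completes, so a clean shutdown falls through here without being logged as a failure.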
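The SetHeader, SendHeader, SetSendCompressor, ClientSupportedCompressors, and SetTrailer helpers above all require the context passed to the server's handler. A minimal sketch of combining them inside a handler body; the package and function names, the metadata keys, and the choice of gzip are illustrative assumptions:

// Package rpcmeta is a hypothetical helper for server handlers.
package rpcmeta

import (
	"context"

	"google.golang.org/grpc"
	_ "google.golang.org/grpc/encoding/gzip" // registers the "gzip" compressor
	"google.golang.org/grpc/metadata"
)

// SetResponseMeta must receive the context passed to a server handler;
// otherwise every helper below fails to find the transport stream.
func SetResponseMeta(ctx context.Context) error {
	// Header metadata is buffered and flushed with the first response
	// message or the RPC status, whichever comes first.
	if err := grpc.SetHeader(ctx, metadata.Pairs("x-request-id", "42")); err != nil {
		return err
	}
	// Only select gzip if the client advertised it in grpc-accept-encoding;
	// this sidesteps the validateSendCompressor error path above.
	names, err := grpc.ClientSupportedCompressors(ctx)
	if err != nil {
		return err
	}
	for _, name := range names {
		if name == "gzip" {
			if err := grpc.SetSendCompressor(ctx, "gzip"); err != nil {
				return err
			}
			break
		}
	}
	// Trailer metadata is merged across calls and sent when the handler returns.
	return grpc.SetTrailer(ctx, metadata.Pairs("x-elapsed-ms", "1"))
}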