import (
"bufio"
"bytes"
- "context"
"errors"
"fmt"
"io"
"os"
"strconv"
- "time"
)
func listGroupsFromReader(u *User, r io.Reader) ([]string, error) {
}
func listGroups(u *User) ([]string, error) {
- if defaultUserdbClient.isUsable() {
- ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
- defer cancel()
- if ids, ok, err := defaultUserdbClient.lookupGroupIds(ctx, u.Username); ok {
- return ids, err
- }
- }
f, err := os.Open(groupFile)
if err != nil {
return nil, err
import (
"bufio"
"bytes"
- "context"
"errors"
"io"
"os"
"strconv"
"strings"
- "time"
)
// lineFunc returns a value, an error, or (nil, nil) to skip the row.
}
func lookupGroup(groupname string) (*Group, error) {
- if defaultUserdbClient.isUsable() {
- ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
- defer cancel()
- if g, ok, err := defaultUserdbClient.lookupGroup(ctx, groupname); ok {
- return g, err
- }
- }
f, err := os.Open(groupFile)
if err != nil {
return nil, err
}
func lookupGroupId(id string) (*Group, error) {
- if defaultUserdbClient.isUsable() {
- ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
- defer cancel()
- if g, ok, err := defaultUserdbClient.lookupGroupId(ctx, id); ok {
- return g, err
- }
- }
f, err := os.Open(groupFile)
if err != nil {
return nil, err
}
func lookupUser(username string) (*User, error) {
- if defaultUserdbClient.isUsable() {
- ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
- defer cancel()
- if u, ok, err := defaultUserdbClient.lookupUser(ctx, username); ok {
- return u, err
- }
- }
f, err := os.Open(userFile)
if err != nil {
return nil, err
}
func lookupUserId(uid string) (*User, error) {
- if defaultUserdbClient.isUsable() {
- ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
- defer cancel()
- if u, ok, err := defaultUserdbClient.lookupUserId(ctx, uid); ok {
- return u, err
- }
- }
f, err := os.Open(userFile)
if err != nil {
return nil, err
is cgo-based and relies on the standard C library (libc) routines such as
getpwuid_r, getgrnam_r, and getgrouplist.
-For Linux, the pure Go implementation queries the systemd-userdb service first.
-If the service is not available, it falls back to parsing /etc/passwd and
-/etc/group.
-
When cgo is available, and the required routines are implemented in libc
for a particular platform, cgo-based (libc-backed) code is used.
This can be overridden by using osusergo build tag, which enforces
+++ /dev/null
-// Copyright 2023 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package user
-
-// userdbClient queries the io.systemd.UserDatabase VARLINK interface provided by
-// systemd-userdbd.service(8) on Linux for obtaining full user/group details
-// even when cgo is not available.
-// VARLINK protocol: https://varlink.org
-// Systemd userdb VARLINK interface https://systemd.io/USER_GROUP_API
-// dir contains multiple varlink service sockets implementing the userdb interface.
-type userdbClient struct {
- dir string
-}
-
-// IsUsable checks if the client can be used to make queries.
-func (cl userdbClient) isUsable() bool {
- return len(cl.dir) != 0
-}
-
-var defaultUserdbClient userdbClient
+++ /dev/null
-// Copyright 2023 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-//go:build linux
-
-package user
-
-import (
- "bytes"
- "context"
- "errors"
- "fmt"
- "io"
- "io/fs"
- "os"
- "strconv"
- "strings"
- "sync"
- "syscall"
- "unicode/utf16"
- "unicode/utf8"
-)
-
-const (
- // Well known multiplexer service.
- svcMultiplexer = "io.systemd.Multiplexer"
-
- userdbNamespace = "io.systemd.UserDatabase"
-
- // io.systemd.UserDatabase VARLINK interface methods.
- mGetGroupRecord = userdbNamespace + ".GetGroupRecord"
- mGetUserRecord = userdbNamespace + ".GetUserRecord"
- mGetMemberships = userdbNamespace + ".GetMemberships"
-
- // io.systemd.UserDatabase VARLINK interface errors.
- errNoRecordFound = userdbNamespace + ".NoRecordFound"
- errServiceNotAvailable = userdbNamespace + ".ServiceNotAvailable"
-)
-
-func init() {
- defaultUserdbClient.dir = "/run/systemd/userdb"
-}
-
-// userdbCall represents a VARLINK service call sent to systemd-userdb.
-// method is the VARLINK method to call.
-// parameters are the VARLINK parameters to pass.
-// more indicates if more responses are expected.
-// fastest indicates if only the fastest response should be returned.
-type userdbCall struct {
- method string
- parameters callParameters
- more bool
- fastest bool
-}
-
-func (u userdbCall) marshalJSON(service string) ([]byte, error) {
- params, err := u.parameters.marshalJSON(service)
- if err != nil {
- return nil, err
- }
- var data bytes.Buffer
- data.WriteString(`{"method":"`)
- data.WriteString(u.method)
- data.WriteString(`","parameters":`)
- data.Write(params)
- if u.more {
- data.WriteString(`,"more":true`)
- }
- data.WriteString(`}`)
- return data.Bytes(), nil
-}
-
-type callParameters struct {
- uid *int64
- userName string
- gid *int64
- groupName string
-}
-
-func (c callParameters) marshalJSON(service string) ([]byte, error) {
- var data bytes.Buffer
- data.WriteString(`{"service":"`)
- data.WriteString(service)
- data.WriteString(`"`)
- if c.uid != nil {
- data.WriteString(`,"uid":`)
- data.WriteString(strconv.FormatInt(*c.uid, 10))
- }
- if c.userName != "" {
- data.WriteString(`,"userName":"`)
- data.WriteString(c.userName)
- data.WriteString(`"`)
- }
- if c.gid != nil {
- data.WriteString(`,"gid":`)
- data.WriteString(strconv.FormatInt(*c.gid, 10))
- }
- if c.groupName != "" {
- data.WriteString(`,"groupName":"`)
- data.WriteString(c.groupName)
- data.WriteString(`"`)
- }
- data.WriteString(`}`)
- return data.Bytes(), nil
-}
-
-type userdbReply struct {
- continues bool
- errorStr string
-}
-
-func (u *userdbReply) unmarshalJSON(data []byte) error {
- var (
- kContinues = []byte(`"continues"`)
- kError = []byte(`"error"`)
- )
- if i := bytes.Index(data, kContinues); i != -1 {
- continues, err := parseJSONBoolean(data[i+len(kContinues):])
- if err != nil {
- return err
- }
- u.continues = continues
- }
- if i := bytes.Index(data, kError); i != -1 {
- errStr, err := parseJSONString(data[i+len(kError):])
- if err != nil {
- return err
- }
- u.errorStr = errStr
- }
- return nil
-}
-
-// response is the parsed reply from a method call to systemd-userdb.
-// data is one or more VARLINK response parameters separated by 0.
-// handled indicates if the call was handled by systemd-userdb.
-// err is any error encountered.
-type response struct {
- data []byte
- handled bool
- err error
-}
-
-// querySocket calls the io.systemd.UserDatabase VARLINK interface at sock with request.
-// Multiple replies can be read by setting more to true in the request.
-// Reply parameters are accumulated separated by 0, if there are many.
-// Replies with io.systemd.UserDatabase.NoRecordFound errors are skipped.
-// Other UserDatabase errors are returned as is.
-// If the socket does not exist, or if the io.systemd.UserDatabase.ServiceNotAvailable
-// error is seen in a response, the query is considered unhandled.
-func querySocket(ctx context.Context, sock string, request []byte) response {
- sockFd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM, 0)
- if err != nil {
- return response{err: err}
- }
- defer syscall.Close(sockFd)
- if err := syscall.Connect(sockFd, &syscall.SockaddrUnix{Name: sock}); err != nil {
- if errors.Is(err, os.ErrNotExist) {
- return response{err: err}
- }
- return response{handled: true, err: err}
- }
-
- // Null terminate request.
- if request[len(request)-1] != 0 {
- request = append(request, 0)
- }
-
- // Write request to socket.
- written := 0
- for written < len(request) {
- if ctx.Err() != nil {
- return response{handled: true, err: ctx.Err()}
- }
- if n, err := syscall.Write(sockFd, request[written:]); err != nil {
- return response{handled: true, err: err}
- } else {
- written += n
- }
- }
-
- // Read response.
- var resp bytes.Buffer
- for {
- if ctx.Err() != nil {
- return response{handled: true, err: ctx.Err()}
- }
- buf := make([]byte, 4096)
- if n, err := syscall.Read(sockFd, buf); err != nil {
- return response{handled: true, err: err}
- } else if n > 0 {
- resp.Write(buf[:n])
- if buf[n-1] == 0 {
- break
- }
- } else {
- // EOF
- break
- }
- }
-
- if resp.Len() == 0 {
- return response{handled: true}
- }
-
- buf := resp.Bytes()
- // Remove trailing 0.
- buf = buf[:len(buf)-1]
- // Split into VARLINK messages.
- msgs := bytes.Split(buf, []byte{0})
-
- // Parse VARLINK messages.
- for _, m := range msgs {
- var resp userdbReply
- if err := resp.unmarshalJSON(m); err != nil {
- return response{handled: true, err: err}
- }
- // Handle VARLINK message errors.
- switch e := resp.errorStr; e {
- case "":
- case errNoRecordFound: // Ignore not found error.
- continue
- case errServiceNotAvailable:
- return response{}
- default:
- return response{handled: true, err: errors.New(e)}
- }
- if !resp.continues {
- break
- }
- }
- return response{data: buf, handled: true, err: ctx.Err()}
-}
-
-// queryMany calls the io.systemd.UserDatabase VARLINK interface on many services at once.
-// ss is a slice of userdb services to call. Each service must have a socket in cl.dir.
-// c is sent to all services in ss. If c.fastest is true, only the fastest reply is read.
-// Otherwise all replies are aggregated. um is called with aggregated reply parameters.
-// queryMany returns the first error encountered. The first result is false if no userdb
-// socket is available or if all requests time out.
-func (cl userdbClient) queryMany(ctx context.Context, ss []string, c *userdbCall, um jsonUnmarshaler) (bool, error) {
- responseCh := make(chan response, len(ss))
-
- ctx, cancel := context.WithCancel(ctx)
- defer cancel()
-
- // Query all services in parallel.
- var workers sync.WaitGroup
- for _, svc := range ss {
- data, err := c.marshalJSON(svc)
- if err != nil {
- return true, err
- }
- // Spawn worker to query service.
- workers.Add(1)
- go func(sock string, data []byte) {
- defer workers.Done()
- responseCh <- querySocket(ctx, sock, data)
- }(cl.dir+"/"+svc, data)
- }
-
- go func() {
- // Clean up workers.
- workers.Wait()
- close(responseCh)
- }()
-
- var result bytes.Buffer
- var notOk int
-RecvResponses:
- for {
- select {
- case resp, ok := <-responseCh:
- if !ok {
- // Responses channel is closed so stop reading.
- break RecvResponses
- }
- if resp.err != nil {
- // querySocket only returns unrecoverable errors,
- // so return the first one received.
- return true, resp.err
- }
- if !resp.handled {
- notOk++
- continue
- }
-
- first := result.Len() == 0
- result.Write(resp.data)
- if first && c.fastest {
- // Return the fastest response.
- break RecvResponses
- }
- case <-ctx.Done():
- // If requests time out, userdb is unavailable.
- return ctx.Err() != context.DeadlineExceeded, nil
- }
- }
- // If all sockets are not ok, userdb is unavailable.
- if notOk == len(ss) {
- return false, nil
- }
- return true, um.unmarshalJSON(result.Bytes())
-}
-
-// services enumerates userdb service sockets in dir.
-// If ok is false, io.systemd.UserDatabase service does not exist.
-func (cl userdbClient) services() (s []string, ok bool, err error) {
- var entries []fs.DirEntry
- if entries, err = os.ReadDir(cl.dir); err != nil {
- ok = !os.IsNotExist(err)
- return
- }
- ok = true
- for _, ent := range entries {
- s = append(s, ent.Name())
- }
- return
-}
-
-// query looks up users/groups on the io.systemd.UserDatabase VARLINK interface.
-// If the multiplexer service is available, the call is sent only to it.
-// Otherwise, the call is sent simultaneously to all UserDatabase services in cl.dir.
-// The fastest reply is read and parsed. All other requests are cancelled.
-// If the service is unavailable, the first result is false.
-// The service is considered unavailable if the requests time-out as well.
-func (cl userdbClient) query(ctx context.Context, call *userdbCall, um jsonUnmarshaler) (bool, error) {
- services := []string{svcMultiplexer}
- if _, err := os.Stat(cl.dir + "/" + svcMultiplexer); err != nil {
- // No mux service so call all available services.
- var ok bool
- if services, ok, err = cl.services(); !ok || err != nil {
- return ok, err
- }
- }
- call.fastest = true
- if ok, err := cl.queryMany(ctx, services, call, um); !ok || err != nil {
- return ok, err
- }
- return true, nil
-}
-
-type jsonUnmarshaler interface {
- unmarshalJSON([]byte) error
-}
-
-func isSpace(c byte) bool {
- return c == ' ' || c == '\t' || c == '\r' || c == '\n'
-}
-
-// findElementStart returns a slice of r that starts at the next JSON element.
-// It skips over valid JSON space characters and checks for the colon separator.
-func findElementStart(r []byte) ([]byte, error) {
- var idx int
- var b byte
- colon := byte(':')
- var seenColon bool
- for idx, b = range r {
- if isSpace(b) {
- continue
- }
- if !seenColon && b == colon {
- seenColon = true
- continue
- }
- // Spotted colon and b is not a space, so value starts here.
- if seenColon {
- break
- }
- return nil, errors.New("expected colon, got invalid character: " + string(b))
- }
- if !seenColon {
- return nil, errors.New("expected colon, got end of input")
- }
- return r[idx:], nil
-}
-
-// parseJSONString reads a JSON string from r.
-func parseJSONString(r []byte) (string, error) {
- r, err := findElementStart(r)
- if err != nil {
- return "", err
- }
- // Smallest valid string is `""`.
- if l := len(r); l < 2 {
- return "", errors.New("unexpected end of input")
- } else if l == 2 {
- if bytes.Equal(r, []byte(`""`)) {
- return "", nil
- }
- return "", errors.New("invalid string")
- }
-
- if c := r[0]; c != '"' {
- return "", errors.New(`expected " got ` + string(c))
- }
- // Advance over opening quote.
- r = r[1:]
-
- var value strings.Builder
- var inEsc bool
- var inUEsc bool
- var strEnds bool
- reader := bytes.NewReader(r)
- for {
- if value.Len() > 4096 {
- return "", errors.New("string too large")
- }
-
- // Parse unicode escape sequences.
- if inUEsc {
- maybeRune := make([]byte, 4)
- n, err := reader.Read(maybeRune)
- if err != nil || n != 4 {
- return "", fmt.Errorf("invalid unicode escape sequence \\u%s", string(maybeRune))
- }
- prn, err := strconv.ParseUint(string(maybeRune), 16, 32)
- if err != nil {
- return "", fmt.Errorf("invalid unicode escape sequence \\u%s", string(maybeRune))
- }
- rn := rune(prn)
- if !utf16.IsSurrogate(rn) {
- value.WriteRune(rn)
- inUEsc = false
- continue
- }
- // rn maybe a high surrogate; read the low surrogate.
- maybeRune = make([]byte, 6)
- n, err = reader.Read(maybeRune)
- if err != nil || n != 6 || maybeRune[0] != '\\' || maybeRune[1] != 'u' {
- // Not a valid UTF-16 surrogate pair.
- if _, err := reader.Seek(int64(-n), io.SeekCurrent); err != nil {
- return "", err
- }
- // Invalid low surrogate; write the replacement character.
- value.WriteRune(utf8.RuneError)
- } else {
- rn1, err := strconv.ParseUint(string(maybeRune[2:]), 16, 32)
- if err != nil {
- return "", fmt.Errorf("invalid unicode escape sequence %s", string(maybeRune))
- }
- // Check if rn and rn1 are valid UTF-16 surrogate pairs.
- if dec := utf16.DecodeRune(rn, rune(rn1)); dec != utf8.RuneError {
- n = utf8.EncodeRune(maybeRune, dec)
- // Write the decoded rune.
- value.Write(maybeRune[:n])
- }
- }
- inUEsc = false
- continue
- }
-
- if inEsc {
- b, err := reader.ReadByte()
- if err != nil {
- return "", err
- }
- switch b {
- case 'b':
- value.WriteByte('\b')
- case 'f':
- value.WriteByte('\f')
- case 'n':
- value.WriteByte('\n')
- case 'r':
- value.WriteByte('\r')
- case 't':
- value.WriteByte('\t')
- case 'u':
- inUEsc = true
- case '/':
- value.WriteByte('/')
- case '\\':
- value.WriteByte('\\')
- case '"':
- value.WriteByte('"')
- default:
- return "", errors.New("unexpected character in escape sequence " + string(b))
- }
- inEsc = false
- continue
- } else {
- rn, _, err := reader.ReadRune()
- if err != nil {
- if err == io.EOF {
- break
- }
- return "", err
- }
- if rn == '\\' {
- inEsc = true
- continue
- }
- if rn == '"' {
- // String ends on un-escaped quote.
- strEnds = true
- break
- }
- value.WriteRune(rn)
- }
- }
- if !strEnds {
- return "", errors.New("unexpected end of input")
- }
- return value.String(), nil
-}
-
-// parseJSONInt64 reads a 64 bit integer from r.
-func parseJSONInt64(r []byte) (int64, error) {
- r, err := findElementStart(r)
- if err != nil {
- return 0, err
- }
- var num strings.Builder
- for _, b := range r {
- // int64 max is 19 digits long.
- if num.Len() == 20 {
- return 0, errors.New("number too large")
- }
- if strings.ContainsRune("0123456789", rune(b)) {
- num.WriteByte(b)
- } else {
- break
- }
- }
- n, err := strconv.ParseInt(num.String(), 10, 64)
- return int64(n), err
-}
-
-// parseJSONBoolean reads a boolean from r.
-func parseJSONBoolean(r []byte) (bool, error) {
- r, err := findElementStart(r)
- if err != nil {
- return false, err
- }
- if bytes.HasPrefix(r, []byte("true")) {
- return true, nil
- }
- if bytes.HasPrefix(r, []byte("false")) {
- return false, nil
- }
- return false, errors.New("unable to parse boolean value")
-}
-
-type groupRecord struct {
- groupName string
- gid int64
-}
-
-func (g *groupRecord) unmarshalJSON(data []byte) error {
- var (
- kGroupName = []byte(`"groupName"`)
- kGid = []byte(`"gid"`)
- )
- if i := bytes.Index(data, kGroupName); i != -1 {
- groupname, err := parseJSONString(data[i+len(kGroupName):])
- if err != nil {
- return err
- }
- g.groupName = groupname
- }
- if i := bytes.Index(data, kGid); i != -1 {
- gid, err := parseJSONInt64(data[i+len(kGid):])
- if err != nil {
- return err
- }
- g.gid = gid
- }
- return nil
-}
-
-// queryGroupDb queries the userdb interface for a gid, groupname, or both.
-func (cl userdbClient) queryGroupDb(ctx context.Context, gid *int64, groupname string) (*Group, bool, error) {
- group := groupRecord{}
- request := userdbCall{
- method: mGetGroupRecord,
- parameters: callParameters{gid: gid, groupName: groupname},
- }
- if ok, err := cl.query(ctx, &request, &group); !ok || err != nil {
- return nil, ok, fmt.Errorf("error querying systemd-userdb group record: %s", err)
- }
- return &Group{
- Name: group.groupName,
- Gid: strconv.FormatInt(group.gid, 10),
- }, true, nil
-}
-
-type userRecord struct {
- userName string
- realName string
- uid int64
- gid int64
- homeDirectory string
-}
-
-func (u *userRecord) unmarshalJSON(data []byte) error {
- var (
- kUserName = []byte(`"userName"`)
- kRealName = []byte(`"realName"`)
- kUid = []byte(`"uid"`)
- kGid = []byte(`"gid"`)
- kHomeDirectory = []byte(`"homeDirectory"`)
- )
- if i := bytes.Index(data, kUserName); i != -1 {
- username, err := parseJSONString(data[i+len(kUserName):])
- if err != nil {
- return err
- }
- u.userName = username
- }
- if i := bytes.Index(data, kRealName); i != -1 {
- realname, err := parseJSONString(data[i+len(kRealName):])
- if err != nil {
- return err
- }
- u.realName = realname
- }
- if i := bytes.Index(data, kUid); i != -1 {
- uid, err := parseJSONInt64(data[i+len(kUid):])
- if err != nil {
- return err
- }
- u.uid = uid
- }
- if i := bytes.Index(data, kGid); i != -1 {
- gid, err := parseJSONInt64(data[i+len(kGid):])
- if err != nil {
- return err
- }
- u.gid = gid
- }
- if i := bytes.Index(data, kHomeDirectory); i != -1 {
- homedir, err := parseJSONString(data[i+len(kHomeDirectory):])
- if err != nil {
- return err
- }
- u.homeDirectory = homedir
- }
- return nil
-}
-
-// queryUserDb queries the userdb interface for a uid, username, or both.
-func (cl userdbClient) queryUserDb(ctx context.Context, uid *int64, username string) (*User, bool, error) {
- user := userRecord{}
- request := userdbCall{
- method: mGetUserRecord,
- parameters: callParameters{
- uid: uid,
- userName: username,
- },
- }
- if ok, err := cl.query(ctx, &request, &user); !ok || err != nil {
- return nil, ok, fmt.Errorf("error querying systemd-userdb user record: %s", err)
- }
- return &User{
- Uid: strconv.FormatInt(user.uid, 10),
- Gid: strconv.FormatInt(user.gid, 10),
- Username: user.userName,
- Name: user.realName,
- HomeDir: user.homeDirectory,
- }, true, nil
-}
-
-func (cl userdbClient) lookupGroup(ctx context.Context, groupname string) (*Group, bool, error) {
- return cl.queryGroupDb(ctx, nil, groupname)
-}
-
-func (cl userdbClient) lookupGroupId(ctx context.Context, id string) (*Group, bool, error) {
- gid, err := strconv.ParseInt(id, 10, 64)
- if err != nil {
- return nil, true, err
- }
- return cl.queryGroupDb(ctx, &gid, "")
-}
-
-func (cl userdbClient) lookupUser(ctx context.Context, username string) (*User, bool, error) {
- return cl.queryUserDb(ctx, nil, username)
-}
-
-func (cl userdbClient) lookupUserId(ctx context.Context, id string) (*User, bool, error) {
- uid, err := strconv.ParseInt(id, 10, 64)
- if err != nil {
- return nil, true, err
- }
- return cl.queryUserDb(ctx, &uid, "")
-}
-
-type memberships struct {
- // Keys are groupNames and values are sets of userNames.
- groupUsers map[string]map[string]struct{}
-}
-
-// unmarshalJSON expects many (userName, groupName) records separated by a null byte.
-// This is used to build a membership map.
-func (m *memberships) unmarshalJSON(data []byte) error {
- if m.groupUsers == nil {
- m.groupUsers = make(map[string]map[string]struct{})
- }
- var (
- kUserName = []byte(`"userName"`)
- kGroupName = []byte(`"groupName"`)
- )
- // Split records by null terminator.
- records := bytes.Split(data, []byte{byte(0)})
- for _, rec := range records {
- if len(rec) == 0 {
- continue
- }
- var groupName string
- var userName string
- var err error
- if i := bytes.Index(rec, kGroupName); i != -1 {
- if groupName, err = parseJSONString(rec[i+len(kGroupName):]); err != nil {
- return err
- }
- }
- if i := bytes.Index(rec, kUserName); i != -1 {
- if userName, err = parseJSONString(rec[i+len(kUserName):]); err != nil {
- return err
- }
- }
- // Associate userName with groupName.
- if groupName != "" && userName != "" {
- if _, ok := m.groupUsers[groupName]; ok {
- m.groupUsers[groupName][userName] = struct{}{}
- } else {
- m.groupUsers[groupName] = map[string]struct{}{userName: {}}
- }
- }
- }
- return nil
-}
-
-func (cl userdbClient) lookupGroupIds(ctx context.Context, username string) ([]string, bool, error) {
- services, ok, err := cl.services()
- if !ok || err != nil {
- return nil, ok, err
- }
- // Fetch group memberships for username.
- var ms memberships
- request := userdbCall{
- method: mGetMemberships,
- parameters: callParameters{userName: username},
- more: true,
- }
- if ok, err := cl.queryMany(ctx, services, &request, &ms); !ok || err != nil {
- return nil, ok, fmt.Errorf("error querying systemd-userdb memberships record: %s", err)
- }
- // Fetch user group gid.
- var group groupRecord
- request = userdbCall{
- method: mGetGroupRecord,
- parameters: callParameters{groupName: username},
- }
- if ok, err := cl.query(ctx, &request, &group); !ok || err != nil {
- return nil, ok, err
- }
- gids := []string{strconv.FormatInt(group.gid, 10)}
-
- // Fetch group records for each group.
- for g := range ms.groupUsers {
- var group groupRecord
- request.parameters.groupName = g
- // Query group for gid.
- if ok, err := cl.query(ctx, &request, &group); !ok || err != nil {
- return nil, ok, fmt.Errorf("error querying systemd-userdb group record: %s", err)
- }
- gids = append(gids, strconv.FormatInt(group.gid, 10))
- }
- return gids, true, nil
-}
+++ /dev/null
-// Copyright 2023 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-//go:build linux
-
-package user
-
-import (
- "bytes"
- "context"
- "errors"
- "reflect"
- "sort"
- "strconv"
- "strings"
- "sync"
- "syscall"
- "testing"
- "time"
- "unicode/utf8"
-)
-
-func TestQueryNoUserdb(t *testing.T) {
- cl := &userdbClient{dir: "/non/existent"}
- if _, ok, err := cl.lookupGroup(context.Background(), "stdlibcontrib"); ok {
- t.Fatalf("should fail but lookup has been handled or error is nil: %v", err)
- }
-}
-
-type userdbTestData map[string]udbResponse
-
-type udbResponse struct {
- data []byte
- delay time.Duration
-}
-
-func userdbServer(t *testing.T, sockFn string, data userdbTestData) {
- ready := make(chan struct{})
- go func() {
- if err := serveUserdb(ready, sockFn, data); err != nil {
- t.Error(err)
- }
- }()
- <-ready
-}
-
-func (u userdbTestData) String() string {
- var s strings.Builder
- for k, v := range u {
- s.WriteString("Request:\n")
- s.WriteString(k)
- s.WriteString("\nResponse:\n")
- if v.delay > 0 {
- s.WriteString("Delay: ")
- s.WriteString(v.delay.String())
- s.WriteString("\n")
- }
- s.WriteString("Data:\n")
- s.Write(v.data)
- s.WriteString("\n")
- }
- return s.String()
-}
-
-// serverUserdb is a simple userdb server that replies to VARLINK method calls.
-// A message is sent on the ready channel when the server is ready to accept calls.
-// The server will reply to each request in the data map. If a request is not
-// found in the map, the server will return an error.
-func serveUserdb(ready chan<- struct{}, sockFn string, data userdbTestData) error {
- sockFd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM, 0)
- if err != nil {
- return err
- }
- defer syscall.Close(sockFd)
- if err := syscall.Bind(sockFd, &syscall.SockaddrUnix{Name: sockFn}); err != nil {
- return err
- }
- if err := syscall.Listen(sockFd, 1); err != nil {
- return err
- }
-
- // Send ready signal.
- ready <- struct{}{}
-
- var srvGroup sync.WaitGroup
-
- srvErrs := make(chan error, len(data))
- for len(data) != 0 {
- nfd, _, err := syscall.Accept(sockFd)
- if err != nil {
- syscall.Close(nfd)
- return err
- }
-
- // Read request.
- buf := make([]byte, 4096)
- n, err := syscall.Read(nfd, buf)
- if err != nil {
- syscall.Close(nfd)
- return err
- }
- if n == 0 {
- // Client went away.
- continue
- }
- if buf[n-1] != 0 {
- syscall.Close(nfd)
- return errors.New("request not null terminated")
- }
- // Remove null terminator.
- buf = buf[:n-1]
- got := string(buf)
-
- // Fetch response for request.
- response, ok := data[got]
- if !ok {
- syscall.Close(nfd)
- msg := "unexpected request:\n" + got + "\n\ndata:\n" + data.String()
- return errors.New(msg)
- }
- delete(data, got)
-
- srvGroup.Add(1)
- go func() {
- defer srvGroup.Done()
- if err := serveClient(nfd, response); err != nil {
- srvErrs <- err
- }
- }()
- }
-
- srvGroup.Wait()
- // Combine serve errors if any.
- if len(srvErrs) > 0 {
- var errs []error
- for err := range srvErrs {
- errs = append(errs, err)
- }
- return errors.Join(errs...)
- }
-
- return nil
-}
-
-func serveClient(fd int, response udbResponse) error {
- defer syscall.Close(fd)
- time.Sleep(response.delay)
- data := response.data
- if len(data) != 0 && data[len(data)-1] != 0 {
- data = append(data, 0)
- }
- written := 0
- for written < len(data) {
- if n, err := syscall.Write(fd, data[written:]); err != nil {
- return err
- } else {
- written += n
- }
- }
- return nil
-}
-
-func TestSlowUserdbLookup(t *testing.T) {
- tmpdir := t.TempDir()
- data := userdbTestData{
- `{"method":"io.systemd.UserDatabase.GetGroupRecord","parameters":{"service":"io.systemd.Multiplexer","groupName":"stdlibcontrib"}}`: udbResponse{
- delay: time.Hour,
- },
- }
- userdbServer(t, tmpdir+"/"+svcMultiplexer, data)
- cl := &userdbClient{dir: tmpdir}
- // Lookup should timeout.
- ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond)
- defer cancel()
- if _, ok, _ := cl.lookupGroup(ctx, "stdlibcontrib"); ok {
- t.Fatalf("lookup should not be handled but was")
- }
-}
-
-func TestFastestUserdbLookup(t *testing.T) {
- tmpdir := t.TempDir()
- fastData := userdbTestData{
- `{"method":"io.systemd.UserDatabase.GetGroupRecord","parameters":{"service":"fast","groupName":"stdlibcontrib"}}`: udbResponse{
- data: []byte(
- `{"parameters":{"record":{"groupName":"stdlibcontrib","gid":181,"members":["stdlibcontrib"],"status":{"ecb5a44f1a5846ad871566e113bf8937":{"service":"io.systemd.NameServiceSwitch"}}},"incomplete":false}}`,
- ),
- },
- }
- slowData := userdbTestData{
- `{"method":"io.systemd.UserDatabase.GetGroupRecord","parameters":{"service":"slow","groupName":"stdlibcontrib"}}`: udbResponse{
- delay: 50 * time.Millisecond,
- data: []byte(
- `{"parameters":{"record":{"groupName":"stdlibcontrib","gid":182,"members":["stdlibcontrib"],"status":{"ecb5a44f1a5846ad871566e113bf8937":{"service":"io.systemd.NameServiceSwitch"}}},"incomplete":false}}`,
- ),
- },
- }
- userdbServer(t, tmpdir+"/"+"fast", fastData)
- userdbServer(t, tmpdir+"/"+"slow", slowData)
- cl := &userdbClient{dir: tmpdir}
- group, ok, err := cl.lookupGroup(context.Background(), "stdlibcontrib")
- if !ok {
- t.Fatalf("lookup should be handled but was not")
- }
- if err != nil {
- t.Fatalf("lookup should not fail but did: %v", err)
- }
- if group.Gid != "181" {
- t.Fatalf("lookup should return group 181 but returned %s", group.Gid)
- }
-}
-
-func TestUserdbLookupGroup(t *testing.T) {
- tmpdir := t.TempDir()
- data := userdbTestData{
- `{"method":"io.systemd.UserDatabase.GetGroupRecord","parameters":{"service":"io.systemd.Multiplexer","groupName":"stdlibcontrib"}}`: udbResponse{
- data: []byte(
- `{"parameters":{"record":{"groupName":"stdlibcontrib","gid":181,"members":["stdlibcontrib"],"status":{"ecb5a44f1a5846ad871566e113bf8937":{"service":"io.systemd.NameServiceSwitch"}}},"incomplete":false}}`,
- ),
- },
- }
- userdbServer(t, tmpdir+"/"+svcMultiplexer, data)
-
- groupname := "stdlibcontrib"
- want := &Group{
- Name: "stdlibcontrib",
- Gid: "181",
- }
- cl := &userdbClient{dir: tmpdir}
- got, ok, err := cl.lookupGroup(context.Background(), groupname)
- if !ok {
- t.Fatal("lookup should have been handled")
- }
- if err != nil {
- t.Fatal(err)
- }
- if !reflect.DeepEqual(got, want) {
- t.Fatalf("lookupGroup(%s) = %v, want %v", groupname, got, want)
- }
-}
-
-func TestUserdbLookupUser(t *testing.T) {
- tmpdir := t.TempDir()
- data := userdbTestData{
- `{"method":"io.systemd.UserDatabase.GetUserRecord","parameters":{"service":"io.systemd.Multiplexer","userName":"stdlibcontrib"}}`: udbResponse{
- data: []byte(
- `{"parameters":{"record":{"userName":"stdlibcontrib","uid":181,"gid":181,"realName":"Stdlib Contrib","homeDirectory":"/home/stdlibcontrib","status":{"ecb5a44f1a5846ad871566e113bf8937":{"service":"io.systemd.NameServiceSwitch"}}},"incomplete":false}}`,
- ),
- },
- }
- userdbServer(t, tmpdir+"/"+svcMultiplexer, data)
-
- username := "stdlibcontrib"
- want := &User{
- Uid: "181",
- Gid: "181",
- Username: "stdlibcontrib",
- Name: "Stdlib Contrib",
- HomeDir: "/home/stdlibcontrib",
- }
- cl := &userdbClient{dir: tmpdir}
- got, ok, err := cl.lookupUser(context.Background(), username)
- if !ok {
- t.Fatal("lookup should have been handled")
- }
- if err != nil {
- t.Fatal(err)
- }
- if !reflect.DeepEqual(got, want) {
- t.Fatalf("lookupUser(%s) = %v, want %v", username, got, want)
- }
-}
-
-func TestUserdbLookupGroupIds(t *testing.T) {
- tmpdir := t.TempDir()
- data := userdbTestData{
- `{"method":"io.systemd.UserDatabase.GetMemberships","parameters":{"service":"io.systemd.Multiplexer","userName":"stdlibcontrib"},"more":true}`: udbResponse{
- data: []byte(
- `{"parameters":{"userName":"stdlibcontrib","groupName":"stdlib"},"continues":true}` + "\x00" + `{"parameters":{"userName":"stdlibcontrib","groupName":"contrib"}}`,
- ),
- },
- // group records
- `{"method":"io.systemd.UserDatabase.GetGroupRecord","parameters":{"service":"io.systemd.Multiplexer","groupName":"stdlibcontrib"}}`: udbResponse{
- data: []byte(
- `{"parameters":{"record":{"groupName":"stdlibcontrib","members":["stdlibcontrib"],"gid":181,"status":{"ecb5a44f1a5846ad871566e113bf8937":{"service":"io.systemd.NameServiceSwitch"}}},"incomplete":false}}`,
- ),
- },
- `{"method":"io.systemd.UserDatabase.GetGroupRecord","parameters":{"service":"io.systemd.Multiplexer","groupName":"stdlib"}}`: udbResponse{
- data: []byte(
- `{"parameters":{"record":{"groupName":"stdlib","members":["stdlibcontrib"],"gid":182,"status":{"ecb5a44f1a5846ad871566e113bf8937":{"service":"io.systemd.NameServiceSwitch"}}},"incomplete":false}}`,
- ),
- },
- `{"method":"io.systemd.UserDatabase.GetGroupRecord","parameters":{"service":"io.systemd.Multiplexer","groupName":"contrib"}}`: udbResponse{
- data: []byte(
- `{"parameters":{"record":{"groupName":"contrib","members":["stdlibcontrib"],"gid":183,"status":{"ecb5a44f1a5846ad871566e113bf8937":{"service":"io.systemd.NameServiceSwitch"}}},"incomplete":false}}`,
- ),
- },
- }
- userdbServer(t, tmpdir+"/"+svcMultiplexer, data)
-
- username := "stdlibcontrib"
- want := []string{"181", "182", "183"}
- cl := &userdbClient{dir: tmpdir}
- got, ok, err := cl.lookupGroupIds(context.Background(), username)
- if !ok {
- t.Fatal("lookup should have been handled")
- }
- if err != nil {
- t.Fatal(err)
- }
- // Result order is not specified so sort it.
- sort.Strings(got)
- if !reflect.DeepEqual(got, want) {
- t.Fatalf("lookupGroupIds(%s) = %v, want %v", username, got, want)
- }
-}
-
-var findElementStartTestCases = []struct {
- in []byte
- want []byte
- err bool
-}{
- {in: []byte(`:`), want: []byte(``)},
- {in: []byte(`: `), want: []byte(``)},
- {in: []byte(`:"foo"`), want: []byte(`"foo"`)},
- {in: []byte(` :"foo"`), want: []byte(`"foo"`)},
- {in: []byte(` 1231 :"foo"`), err: true},
- {in: []byte(``), err: true},
- {in: []byte(`"foo"`), err: true},
- {in: []byte(`foo`), err: true},
-}
-
-func TestFindElementStart(t *testing.T) {
- for i, tc := range findElementStartTestCases {
- t.Run("#"+strconv.Itoa(i), func(t *testing.T) {
- got, err := findElementStart(tc.in)
- if tc.err && err == nil {
- t.Errorf("want err for findElementStart(%s), got nil", tc.in)
- }
- if !tc.err {
- if err != nil {
- t.Errorf("findElementStart(%s) unexpected error: %s", tc.in, err.Error())
- }
- if !bytes.Contains(tc.in, got) {
- t.Errorf("%s should contain %s but does not", tc.in, got)
- }
- }
- })
- }
-}
-
-func FuzzFindElementStart(f *testing.F) {
- for _, tc := range findElementStartTestCases {
- if !tc.err {
- f.Add(tc.in)
- }
- }
- f.Fuzz(func(t *testing.T, b []byte) {
- if out, err := findElementStart(b); err == nil && !bytes.Contains(b, out) {
- t.Errorf("%s, %v", out, err)
- }
- })
-}
-
-var parseJSONStringTestCases = []struct {
- in []byte
- want string
- err bool
-}{
- {in: []byte(`:""`)},
- {in: []byte(`:"\n"`), want: "\n"},
- {in: []byte(`: "\""`), want: "\""},
- {in: []byte(`:"\t \\"`), want: "\t \\"},
- {in: []byte(`:"\\\\"`), want: `\\`},
- {in: []byte(`::`), err: true},
- {in: []byte(`""`), err: true},
- {in: []byte(`"`), err: true},
- {in: []byte(":\"0\xE5"), err: true},
- {in: []byte{':', '"', 0xFE, 0xFE, 0xFF, 0xFF, '"'}, want: "\uFFFD\uFFFD\uFFFD\uFFFD"},
- {in: []byte(`:"\u0061a"`), want: "aa"},
- {in: []byte(`:"\u0159\u0170"`), want: "řŰ"},
- {in: []byte(`:"\uD800\uDC00"`), want: "\U00010000"},
- {in: []byte(`:"\uD800"`), want: "\uFFFD"},
- {in: []byte(`:"\u000"`), err: true},
- {in: []byte(`:"\u00MF"`), err: true},
- {in: []byte(`:"\uD800\uDC0"`), err: true},
-}
-
-func TestParseJSONString(t *testing.T) {
- for i, tc := range parseJSONStringTestCases {
- t.Run("#"+strconv.Itoa(i), func(t *testing.T) {
- got, err := parseJSONString(tc.in)
- if tc.err && err == nil {
- t.Errorf("want err for parseJSONString(%s), got nil", tc.in)
- }
- if !tc.err {
- if err != nil {
- t.Errorf("parseJSONString(%s) unexpected error: %s", tc.in, err.Error())
- }
- if tc.want != got {
- t.Errorf("parseJSONString(%s) = %s, want %s", tc.in, got, tc.want)
- }
- }
- })
- }
-}
-
-func FuzzParseJSONString(f *testing.F) {
- for _, tc := range parseJSONStringTestCases {
- f.Add(tc.in)
- }
- f.Fuzz(func(t *testing.T, b []byte) {
- if out, err := parseJSONString(b); err == nil && !utf8.ValidString(out) {
- t.Errorf("parseJSONString(%s) = %s, invalid string", b, out)
- }
- })
-}
-
-var parseJSONInt64TestCases = []struct {
- in []byte
- want int64
- err bool
-}{
- {in: []byte(":1235"), want: 1235},
- {in: []byte(": 123"), want: 123},
- {in: []byte(":0")},
- {in: []byte(":5012313123131231"), want: 5012313123131231},
- {in: []byte("1231"), err: true},
-}
-
-func TestParseJSONInt64(t *testing.T) {
- for i, tc := range parseJSONInt64TestCases {
- t.Run("#"+strconv.Itoa(i), func(t *testing.T) {
- got, err := parseJSONInt64(tc.in)
- if tc.err && err == nil {
- t.Errorf("want err for parseJSONInt64(%s), got nil", tc.in)
- }
- if !tc.err {
- if err != nil {
- t.Errorf("parseJSONInt64(%s) unexpected error: %s", tc.in, err.Error())
- }
- if tc.want != got {
- t.Errorf("parseJSONInt64(%s) = %d, want %d", tc.in, got, tc.want)
- }
- }
- })
- }
-}
-
-func FuzzParseJSONInt64(f *testing.F) {
- for _, tc := range parseJSONInt64TestCases {
- f.Add(tc.in)
- }
- f.Fuzz(func(t *testing.T, b []byte) {
- if out, err := parseJSONInt64(b); err == nil &&
- !bytes.Contains(b, []byte(strconv.FormatInt(out, 10))) {
- t.Errorf("parseJSONInt64(%s) = %d, %v", b, out, err)
- }
- })
-}
-
-var parseJSONBooleanTestCases = []struct {
- in []byte
- want bool
- err bool
-}{
- {in: []byte(": true "), want: true},
- {in: []byte(":true "), want: true},
- {in: []byte(": false "), want: false},
- {in: []byte(":false "), want: false},
- {in: []byte("true"), err: true},
- {in: []byte("false"), err: true},
- {in: []byte("foo"), err: true},
-}
-
-func TestParseJSONBoolean(t *testing.T) {
- for i, tc := range parseJSONBooleanTestCases {
- t.Run("#"+strconv.Itoa(i), func(t *testing.T) {
- got, err := parseJSONBoolean(tc.in)
- if tc.err && err == nil {
- t.Errorf("want err for parseJSONBoolean(%s), got nil", tc.in)
- }
- if !tc.err {
- if err != nil {
- t.Errorf("parseJSONBoolean(%s) unexpected error: %s", tc.in, err.Error())
- }
- if tc.want != got {
- t.Errorf("parseJSONBoolean(%s) = %t, want %t", tc.in, got, tc.want)
- }
- }
- })
- }
-}
-
-func FuzzParseJSONBoolean(f *testing.F) {
- for _, tc := range parseJSONBooleanTestCases {
- f.Add(tc.in)
- }
- f.Fuzz(func(t *testing.T, b []byte) {
- if out, err := parseJSONBoolean(b); err == nil && !bytes.Contains(b, []byte(strconv.FormatBool(out))) {
- t.Errorf("parseJSONBoolean(%s) = %t, %v", b, out, err)
- }
- })
-}
+++ /dev/null
-// Copyright 2023 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-//go:build !linux
-
-package user
-
-import "context"
-
-func (cl userdbClient) lookupGroup(_ context.Context, _ string) (*Group, bool, error) {
- return nil, false, nil
-}
-
-func (cl userdbClient) lookupGroupId(_ context.Context, _ string) (*Group, bool, error) {
- return nil, false, nil
-}
-
-func (cl userdbClient) lookupUser(_ context.Context, _ string) (*User, bool, error) {
- return nil, false, nil
-}
-
-func (cl userdbClient) lookupUserId(_ context.Context, _ string) (*User, bool, error) {
- return nil, false, nil
-}
-
-func (cl userdbClient) lookupGroupIds(_ context.Context, _ string) ([]string, bool, error) {
- return nil, false, nil
-}