// Package zookeeper contains the ZooKeeper store implementation.
package zookeeper

import (
	"context"
	"errors"
	"strings"
	"time"

	"github.com/go-zookeeper/zk"
	"github.com/kvtools/valkeyrie"
	"github.com/kvtools/valkeyrie/store"
)

const (
	// StoreName is the name under which this store is registered with valkeyrie.
	StoreName = "zookeeper"

	// SOH is the ASCII "start of heading" control character (0x01).
	// get and getW treat a znode whose data is exactly SOH (or empty) as not
	// yet written and re-sync before reading it again.
	SOH = "\x01"

	// defaultTimeout is the timeout passed to zk.Connect when Config does not
	// set ConnectionTimeout.
	defaultTimeout = 10 * time.Second

	// syncRetryLimit is how many Sync-and-reread rounds get and getW perform
	// before giving up and returning whatever data they last read.
	syncRetryLimit = 5
)

// init registers the ZooKeeper store with valkeyrie under StoreName, using
// newStore as the constructor valkeyrie invokes for that name.
func init() {
	valkeyrie.Register(StoreName, newStore)
}

// Config is the ZooKeeper store configuration accepted by New.
// Every field is optional; its zero value keeps the default behaviour.
type Config struct {
	// ConnectionTimeout, when non-zero, replaces defaultTimeout as the
	// timeout passed to zk.Connect.
	ConnectionTimeout time.Duration

	// Username and Password are sent as "digest" credentials
	// ("Username:Password") via AddAuth right after connecting. Nothing is
	// sent when Username is empty.
	Username string
	Password string

	// MaxBufferSize, when positive, is applied with zk.WithMaxConnBufferSize.
	MaxBufferSize int
}

// newStore is the constructor registered with valkeyrie. It accepts a nil
// configuration (defaults) or a *Config. Any other configuration type is
// rejected with a store.InvalidConfigurationError.
func newStore(ctx context.Context, endpoints []string, options valkeyrie.Config) (store.Store, error) {
	if options == nil {
		return New(ctx, endpoints, nil)
	}

	cfg, isZKConfig := options.(*Config)
	if !isZKConfig {
		return nil, &store.InvalidConfigurationError{Store: StoreName, Config: options}
	}

	return New(ctx, endpoints, cfg)
}

// Store implements the store.Store interface on top of a single ZooKeeper
// connection.
type Store struct {
	// client is the connection every store operation goes through.
	client *zk.Conn

	// timeout is the timeout handed to zk.Connect (defaultTimeout unless
	// overridden by Config.ConnectionTimeout).
	timeout time.Duration
}

// New connects to the ZooKeeper servers listed in endpoints and returns a
// Store backed by that connection.
//
// options may be nil, in which case defaults are used; see Config for the
// effect of each field. When digest authentication is configured and AddAuth
// fails, the freshly opened connection is closed before the error is returned,
// so a failed New never leaks a connection.
func New(_ context.Context, endpoints []string, options *Config) (store.Store, error) {
	s := &Store{timeout: defaultTimeout}
	if options != nil && options.ConnectionTimeout != 0 {
		s.setTimeout(options.ConnectionTimeout)
	}

	// Connect to Zookeeper.
	var (
		conn *zk.Conn
		err  error
	)
	if options != nil && options.MaxBufferSize > 0 {
		conn, _, err = zk.Connect(endpoints, s.timeout, zk.WithMaxConnBufferSize(options.MaxBufferSize))
	} else {
		conn, _, err = zk.Connect(endpoints, s.timeout)
	}
	if err != nil {
		return nil, err
	}

	if options != nil && options.Username != "" {
		credentials := []byte(options.Username + ":" + options.Password)
		if err := conn.AddAuth("digest", credentials); err != nil {
			// The caller never receives the Store, so nobody else could
			// close this connection: release it here.
			conn.Close()
			return nil, err
		}
	}

	s.client = conn
	return s, nil
}

// setTimeout sets the timeout that New passes to zk.Connect.
// It only has an effect when called before the connection is established.
func (s *Store) setTimeout(timeout time.Duration) {
	s.timeout = timeout
}

// Get returns the value stored at key. The pair's LastIndex is the znode's
// data version, which is what AtomicPut and AtomicDelete compare against.
// A missing key yields store.ErrKeyNotFound.
func (s *Store) Get(_ context.Context, key string, _ *store.ReadOptions) (pair *store.KVPair, err error) {
	value, stat, err := s.get(key)
	if err != nil {
		return nil, err
	}

	pair = &store.KVPair{Key: key, Value: value, LastIndex: uint64(stat.Version)}
	return pair, nil
}

// createFullPath creates every znode along path (a list of path segments).
//
// Intermediate znodes are created empty and persistent, and ones that already
// exist are tolerated. The final znode is created with data, and ephemeral when
// requested. Any error from creating that final znode, including
// zk.ErrNodeExists, is returned as-is. An empty path is a no-op.
func (s *Store) createFullPath(path []string, data []byte, ephemeral bool) error {
	if len(path) == 0 {
		return nil
	}

	acl := zk.WorldACL(zk.PermAll)
	leaf := len(path) - 1

	for depth := 1; depth <= leaf; depth++ {
		dir := "/" + strings.Join(path[:depth], "/")
		if _, err := s.client.Create(dir, []byte{}, 0, acl); err != nil && !errors.Is(err, zk.ErrNodeExists) {
			return err
		}
	}

	var flags int32
	if ephemeral {
		flags = int32(zk.FlagEphemeral)
	}
	_, err := s.client.Create("/"+strings.Join(path, "/"), data, flags, acl)
	return err
}

// Put stores value at key.
//
// When the key does not exist it is created together with any missing parent
// directories. A positive opts.TTL makes that newly created znode ephemeral;
// the TTL duration itself is not used. When the key already exists, its data
// is overwritten regardless of version.
//
// If another client creates the key between the existence check and the
// create, Put falls back to overwriting it. Any other creation error is
// returned (previously these were silently dropped).
func (s *Store) Put(ctx context.Context, key string, value []byte, opts *store.WriteOptions) error {
	fkey := normalize(key)

	exists, err := s.Exists(ctx, key, nil)
	if err != nil {
		return err
	}

	if !exists {
		ephemeral := opts != nil && opts.TTL > 0
		err = s.createFullPath(store.SplitKey(strings.TrimSuffix(key, "/")), value, ephemeral)
		if !errors.Is(err, zk.ErrNodeExists) {
			// Either created successfully (err == nil) or a real failure.
			return err
		}
		// Lost a creation race: the node exists now, so overwrite it below.
	}

	_, err = s.client.Set(fkey, value, -1)
	return err
}

// Delete removes key regardless of its current version.
// A missing key yields store.ErrKeyNotFound.
func (s *Store) Delete(_ context.Context, key string) error {
	if err := s.client.Delete(normalize(key), -1); err != nil {
		if errors.Is(err, zk.ErrNoNode) {
			return store.ErrKeyNotFound
		}
		return err
	}
	return nil
}

// Exists reports whether a znode exists at key. On error it reports false
// alongside the error.
func (s *Store) Exists(_ context.Context, key string, _ *store.ReadOptions) (bool, error) {
	found, _, err := s.client.Exists(normalize(key))
	if err != nil {
		found = false
	}
	return found, err
}

// Watch watches key for data changes.
//
// The returned channel first receives the key's current value. After that it
// receives a new pair each time the node's data changes; any other watch event
// just re-arms the watch. The channel is closed, without delivering an error,
// when the key can no longer be read (for example it does not exist or was
// deleted) or when ctx is cancelled. The returned error is always nil.
func (s *Store) Watch(ctx context.Context, key string, _ *store.ReadOptions) (<-chan *store.KVPair, error) {
	// Catch zk notifications and fire changes into the channel.
	watchCh := make(chan *store.KVPair)

	go func() {
		defer close(watchCh)

		fireEvt := true
		for {
			resp, meta, eventCh, err := s.getW(key)
			if err != nil {
				return
			}

			if fireEvt {
				pair := &store.KVPair{
					Key:       key,
					Value:     resp,
					LastIndex: uint64(meta.Version),
				}
				// Never block unconditionally on the send. A consumer that
				// cancels ctx and stops reading would otherwise leak this
				// goroutine forever.
				select {
				case watchCh <- pair:
				case <-ctx.Done():
					return
				}
			}

			select {
			case e := <-eventCh:
				// Only fire an event if the data in the node changed.
				// Simply reset the watch if this is any other event
				// (e.g. a session event).
				fireEvt = e.Type == zk.EventNodeDataChanged
			case <-ctx.Done():
				// There is no way to stop GetW so just quit.
				return
			}
		}
	}()

	return watchCh, nil
}

// WatchTree watches the direct children of directory.
//
// The returned channel first receives the current children, then a fresh list
// each time the set of children changes; other watch events just re-arm the
// watch. Keys in each list are the child names as returned by ChildrenW
// (relative to directory), not full paths.
//
// If a child disappears while its value is being read, the listing is stale
// and is retried. The channel is closed, without delivering an error, when
// the directory cannot be watched, a child read fails for any other reason,
// or ctx is cancelled. The returned error is always nil.
func (s *Store) WatchTree(ctx context.Context, directory string, opts *store.ReadOptions) (<-chan []*store.KVPair, error) {
	// Catch zk notifications and fire changes into the channel.
	watchCh := make(chan []*store.KVPair)

	go func() {
		defer close(watchCh)

		fireEvt := true
		for {
			keys, _, eventCh, err := s.client.ChildrenW(normalize(directory))
			if err != nil {
				return
			}

			if fireEvt {
				kvs, err := s.getListWithPath(ctx, directory, keys, opts)
				if err != nil {
					// A child vanished between ChildrenW and the read: re-list
					// and try again while the caller is still interested. Any
					// other error would make this a tight infinite retry loop,
					// so end the watch instead (as Watch does).
					if errors.Is(err, store.ErrKeyNotFound) && ctx.Err() == nil {
						continue
					}
					return
				}

				// Never block unconditionally on the send (see Watch).
				select {
				case watchCh <- kvs:
				case <-ctx.Done():
					return
				}
			}

			select {
			case e := <-eventCh:
				// Only fire an event if the children have changed.
				// Simply reset the watch if this is any other event
				// (e.g. a session event).
				fireEvt = e.Type == zk.EventNodeChildrenChanged
			case <-ctx.Done():
				// There is no way to stop ChildrenW so just quit.
				return
			}
		}
	}()

	return watchCh, nil
}

// listChildren returns the names of the direct children of directory, as
// reported by ZooKeeper. A missing directory yields store.ErrKeyNotFound.
func (s *Store) listChildren(directory string) ([]string, error) {
	names, _, err := s.client.Children(normalize(directory))
	switch {
	case err == nil:
		return names, nil
	case errors.Is(err, zk.ErrNoNode):
		return nil, store.ErrKeyNotFound
	default:
		return nil, err
	}
}

// listChildrenRecursive appends to list the full path of every descendant of
// directory, at any depth.
//
// Paths are appended in post-order: a node is appended only after all of its
// own descendants. A missing directory yields store.ErrKeyNotFound. A child
// that is deleted after its parent was listed is skipped instead of failing
// the whole walk.
func (s *Store) listChildrenRecursive(list *[]string, directory string) error {
	children, err := s.listChildren(directory)
	if err != nil {
		return err
	}

	parent := strings.TrimSuffix(directory, "/")
	for _, name := range children {
		child := parent + "/" + name

		err := s.listChildrenRecursive(list, child)
		switch {
		case errors.Is(err, store.ErrKeyNotFound):
			// Deleted concurrently; it is no longer part of the tree.
			continue
		case err != nil && !errors.Is(err, zk.ErrNoChildrenForEphemerals):
			return err
		}

		*list = append(*list, child)
	}

	return nil
}

// List returns every key under directory, at any depth, with full paths as
// keys. A missing directory yields store.ErrKeyNotFound.
//
// If a listed node disappears before its value is read, the listing is stale
// and is taken again. The retry is a loop, not recursion, and stops with
// ctx.Err() once ctx is cancelled, so sustained churn cannot grow the stack
// or spin forever.
func (s *Store) List(ctx context.Context, directory string, opts *store.ReadOptions) ([]*store.KVPair, error) {
	for {
		var children []string
		if err := s.listChildrenRecursive(&children, directory); err != nil {
			return nil, err
		}

		kvs, err := s.getList(ctx, children, opts)
		if err == nil {
			return kvs, nil
		}
		if !errors.Is(err, store.ErrKeyNotFound) {
			return nil, err
		}

		// Node not found: the listing is out of date, retry unless the caller
		// has given up.
		if ctxErr := ctx.Err(); ctxErr != nil {
			return nil, ctxErr
		}
	}
}

// DeleteTree deletes every key under directory, at any depth (the same set of
// keys List returns). The directory node itself is kept. A missing directory
// yields store.ErrKeyNotFound; an empty directory is a no-op.
//
// All deletions are submitted in a single Multi request. Previously only
// direct children were deleted, so any nested key made the request fail
// because ZooKeeper refuses to delete a node that still has children.
func (s *Store) DeleteTree(_ context.Context, directory string) error {
	var descendants []string
	if err := s.listChildrenRecursive(&descendants, directory); err != nil {
		return err
	}
	if len(descendants) == 0 {
		return nil
	}

	// listChildrenRecursive appends a node only after all of its descendants,
	// so deleting in list order removes children before their parents.
	reqs := make([]interface{}, 0, len(descendants))
	for _, path := range descendants {
		reqs = append(reqs, &zk.DeleteRequest{
			Path:    normalize(path),
			Version: -1,
		})
	}

	_, err := s.client.Multi(reqs...)
	return err
}

// AtomicPut writes value at key only if the key is still in the state
// described by previous, and reports whether the write happened.
//
// With a non-nil previous, the write is a compare-and-set against
// previous.LastIndex (the znode data version); a mismatch yields
// store.ErrKeyModified. With a nil previous, the key is created and must not
// already exist (store.ErrKeyExists otherwise). Missing parent directories are
// created first, and a parent created concurrently by another client is not an
// error. The returned pair carries the znode's new version.
func (s *Store) AtomicPut(_ context.Context, key string, value []byte, previous *store.KVPair, _ *store.WriteOptions) (bool, *store.KVPair, error) {
	if previous != nil {
		meta, err := s.client.Set(normalize(key), value, int32(previous.LastIndex))
		if err != nil {
			// Compare Failed.
			if errors.Is(err, zk.ErrBadVersion) {
				return false, nil, store.ErrKeyModified
			}
			return false, nil, err
		}

		pair := &store.KVPair{
			Key:       key,
			Value:     value,
			LastIndex: uint64(meta.Version),
		}

		return true, pair, nil
	}

	// Interpret previous == nil as create operation.
	_, err := s.client.Create(normalize(key), value, 0, zk.WorldACL(zk.PermAll))
	if errors.Is(err, zk.ErrNoNode) {
		// The parent directory does not exist: create it, then retry.
		parts := store.SplitKey(strings.TrimSuffix(key, "/"))
		if len(parts) > 1 {
			// createFullPath reports ErrNodeExists if the last directory was
			// created concurrently; the directory is there either way.
			dirErr := s.createFullPath(parts[:len(parts)-1], []byte{}, false)
			if dirErr != nil && !errors.Is(dirErr, zk.ErrNodeExists) {
				// Failed to create the directory.
				return false, nil, dirErr
			}
		}

		_, err = s.client.Create(normalize(key), value, 0, zk.WorldACL(zk.PermAll))
	}
	if err != nil {
		// Node exists error (when previous nil).
		if errors.Is(err, zk.ErrNodeExists) {
			return false, nil, store.ErrKeyExists
		}
		return false, nil, err
	}

	pair := &store.KVPair{
		Key:       key,
		Value:     value,
		LastIndex: 0, // Newly created nodes have version 0.
	}

	return true, pair, nil
}

// AtomicDelete deletes key only if its znode version still equals
// previous.LastIndex, and reports whether the deletion happened.
//
// A nil previous yields store.ErrPreviousNotSpecified, a missing key
// store.ErrKeyNotFound, and a version mismatch store.ErrKeyModified.
func (s *Store) AtomicDelete(_ context.Context, key string, previous *store.KVPair) (bool, error) {
	if previous == nil {
		return false, store.ErrPreviousNotSpecified
	}

	err := s.client.Delete(normalize(key), int32(previous.LastIndex))
	switch {
	case err == nil:
		return true, nil
	case errors.Is(err, zk.ErrNoNode):
		return false, store.ErrKeyNotFound
	case errors.Is(err, zk.ErrBadVersion):
		// Someone modified the key since previous was read.
		return false, store.ErrKeyModified
	default:
		return false, err
	}
}

// NewLock returns a handle for mutual exclusion on key, backed by the zk.Lock
// recipe on the normalized key path. opts.Value, when set, is the value
// written to the key once the lock is acquired (empty otherwise). No
// ZooKeeper request is made here and the returned error is always nil.
func (s *Store) NewLock(_ context.Context, key string, opts *store.LockOptions) (lock store.Locker, err error) {
	value := []byte("")
	if opts != nil && opts.Value != nil {
		value = opts.Value
	}

	path := normalize(key)
	return &zookeeperLock{
		client: s.client,
		key:    path,
		value:  value,
		lock:   zk.NewLock(s.client, path, zk.WorldACL(zk.PermAll)),
	}, nil
}

// Close closes the underlying ZooKeeper connection. It always returns nil.
func (s *Store) Close() error {
	s.client.Close()
	return nil
}

// get reads the data and stat of the znode at key.
//
// To guard against older versions of valkeyrie creating and writing to znodes
// non-atomically, data that is empty or exactly SOH triggers a Sync and a
// re-read, up to syncRetryLimit times. After that, whatever was last read is
// returned. A missing key yields store.ErrKeyNotFound.
func (s *Store) get(key string) ([]byte, *zk.Stat, error) {
	path := normalize(key)

	var (
		data []byte
		stat *zk.Stat
		err  error
	)
	for attempt := 0; attempt <= syncRetryLimit; attempt++ {
		data, stat, err = s.client.Get(path)
		if errors.Is(err, zk.ErrNoNode) {
			return nil, nil, store.ErrKeyNotFound
		}
		if err != nil {
			return nil, nil, err
		}

		if len(data) > 0 && string(data) != SOH {
			return data, stat, nil
		}

		// No point syncing after the final read.
		if attempt == syncRetryLimit {
			break
		}
		if _, err = s.client.Sync(path); err != nil {
			return nil, nil, err
		}
	}
	return data, stat, nil
}

// getW is get plus a data watch. It reads the znode at key with GetW and
// returns the watch channel from the last read it performed.
//
// The same SOH / empty-data Sync-and-retry guard as get applies, up to
// syncRetryLimit times. A missing key yields store.ErrKeyNotFound.
func (s *Store) getW(key string) ([]byte, *zk.Stat, <-chan zk.Event, error) {
	path := normalize(key)

	var (
		data   []byte
		stat   *zk.Stat
		events <-chan zk.Event
		err    error
	)
	for attempt := 0; attempt <= syncRetryLimit; attempt++ {
		data, stat, events, err = s.client.GetW(path)
		switch {
		case errors.Is(err, zk.ErrNoNode):
			return nil, nil, nil, store.ErrKeyNotFound
		case err != nil:
			return nil, nil, nil, err
		case len(data) > 0 && string(data) != SOH:
			return data, stat, events, nil
		}

		// No point syncing after the final read.
		if attempt == syncRetryLimit {
			break
		}
		if _, err = s.client.Sync(path); err != nil {
			return nil, nil, nil, err
		}
	}
	return data, stat, events, nil
}

// getListWithPath reads each of keys (bare child names, as returned by
// ChildrenW) underneath path, and returns one pair per key in the same order.
//
// Each returned pair's Key is the bare child name, not the full path. The first
// read error aborts the whole call. With no keys, the result is nil.
func (s *Store) getListWithPath(ctx context.Context, path string, keys []string, opts *store.ReadOptions) ([]*store.KVPair, error) {
	var result []*store.KVPair

	prefix := strings.TrimSuffix(path, "/")
	for i := range keys {
		child := keys[i]

		got, err := s.Get(ctx, prefix+normalize(child), opts)
		if err != nil {
			return nil, err
		}

		result = append(result, &store.KVPair{Key: child, Value: got.Value, LastIndex: got.LastIndex})
	}

	return result, nil
}

// getList reads each of keys (full paths) and returns one pair per key in the
// same order, keyed by the path exactly as given.
//
// The ReadOptions argument is ignored; every read uses nil options. The first
// read error aborts the whole call. With no keys, the result is nil.
func (s *Store) getList(ctx context.Context, keys []string, _ *store.ReadOptions) ([]*store.KVPair, error) {
	var pairs []*store.KVPair

	for i := 0; i < len(keys); i++ {
		fullPath := keys[i]

		found, err := s.Get(ctx, strings.TrimSuffix(fullPath, "/"), nil)
		if err != nil {
			return nil, err
		}

		pairs = append(pairs, &store.KVPair{
			Key:       fullPath,
			Value:     found.Value,
			LastIndex: found.LastIndex,
		})
	}

	return pairs, nil
}

// zookeeperLock is the store.Locker returned by NewLock.
type zookeeperLock struct {
	// lock is the zk.Lock recipe created for key.
	lock *zk.Lock

	// client is the connection used to write value to key and to watch key.
	client *zk.Conn

	// key is the normalized znode path the lock was created for.
	key string

	// value is written to key once the lock is held.
	value []byte
}

// Lock blocks until the lock is acquired, then writes the lock's value to the
// key node.
//
// It returns a channel that is closed when the lock should be considered lost:
// the watch on the key cannot be set, the watch is dropped, the session
// expires, another party writes to the key node, or ctx is cancelled.
// Acquisition itself (zk.Lock.Lock) does not observe ctx.
//
// On error the returned channel is nil. If writing the value fails after the
// lock was acquired, the lock is released before returning, so it is never
// left held with nothing monitoring it.
func (l *zookeeperLock) Lock(ctx context.Context) (<-chan struct{}, error) {
	if err := l.lock.Lock(); err != nil {
		return nil, err
	}

	// We hold the lock, we can set our value.
	if _, err := l.client.Set(l.key, l.value, -1); err != nil {
		// Best-effort release: the Set failure is the error worth reporting,
		// so an Unlock error is deliberately dropped here.
		_ = l.lock.Unlock()
		return nil, err
	}

	lostCh := make(chan struct{})
	go l.monitorLock(ctx, lostCh)
	return lostCh, nil
}

// Unlock releases the lock on the key by delegating to the underlying
// zk.Lock. Calling Unlock while not holding the lock returns an error.
func (l *zookeeperLock) Unlock(_ context.Context) error {
	return l.lock.Unlock()
}

// monitorLock watches the key node and closes lostCh as soon as the lock can
// no longer be assumed to be held. That happens when the watch cannot be set,
// the watch is dropped, the session expires, the node's data is changed by
// someone else, or ctx is cancelled. Any other event just re-arms the watch.
func (l *zookeeperLock) monitorLock(ctx context.Context, lostCh chan struct{}) {
	defer close(lostCh)

	for {
		_, _, events, err := l.client.GetW(l.key)
		if err != nil {
			// Without a watch we cannot vouch for the lock: relinquish it.
			return
		}

		select {
		case <-ctx.Done():
			// The caller has requested that we relinquish our lock.
			return
		case e := <-events:
			switch {
			case e.Type == zk.EventNotWatching,
				e.Type == zk.EventSession && e.State == zk.StateExpired:
				// The watch was invalidated or the session expired.
				return
			case e.Type == zk.EventNodeDataChanged:
				// Someone else wrote to the lock node and believes they hold it.
				return
			}
			// Any other event: loop around and re-arm the watch.
		}
	}
}

// normalize turns key into a ZooKeeper path. It guarantees a single leading
// "/" (added when missing) and strips one trailing "/". An empty key or "/"
// normalizes to "".
func normalize(key string) string {
	if !strings.HasPrefix(key, "/") {
		key = "/" + key
	}
	return strings.TrimSuffix(key, "/")
}
