// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "contrib_ops/cpu/bert/attention_common.h"

namespace onnxruntime {
namespace contrib {

// Parameters deduced from node attributes and inputs/outputs.
//
// Every member carries a default member initializer so that a default-initialized
// instance (`AttentionParameters p;`) never exposes an indeterminate value. The
// defaults are the zero values that value-initialization (`AttentionParameters{}`)
// already produced, so callers that value-initialize observe no change.
struct AttentionParameters {
  int batch_size = 0;
  int sequence_length = 0;
  int kv_sequence_length = 0;     // input sequence length of K or V
  int past_sequence_length = 0;   // sequence length in past state of K or V
  int total_sequence_length = 0;  // total sequence length of K or V
  int max_sequence_length = 0;    // max sequence length from 4D mask
  int input_hidden_size = 0;      // first dimension of weights for input projection
  int hidden_size = 0;            // hidden size of Q or K
  int head_size = 0;              // hidden size per head of Q or K
  int v_hidden_size = 0;          // hidden size of V
  int v_head_size = 0;            // hidden size per head of V
  int num_heads = 0;
  int num_splits = 0;  // number of splits for splitkv
  int rotary_dim = 0;  // rotary embedding dimension
  int beam_width = 0;
  bool is_unidirectional = false;
  bool past_present_share_buffer = false;
  bool is_packed_qkv = false;  // whether qkv is packed
  bool do_rotary = false;
  bool broadcast_attn_bias_dim_0 = false;
  bool broadcast_attn_bias_dim_1 = false;
  float mask_filter_value = 0.0f;
  float scale = 0.0f;
  bool use_tf32 = false;
  bool is_output_bnsh = false;  // whether the output format is BNSH
  AttentionMaskType mask_type{};    // value-initialized: underlying value 0
  AttentionQkvFormat qkv_format{};  // value-initialized: underlying value 0
};

// Parameters deduced from node attributes and inputs/outputs.
// Extends AttentionParameters with a token count.
struct PackedAttentionParameters : AttentionParameters {
  // Zero by default so a default-initialized instance never holds an indeterminate value;
  // this is the same value that value-initialization (`{}`) yields.
  int token_count = 0;
};

// Extends AttentionParameters with the additional fields declared below, including
// untyped (void*) pointers for the inputs, caches and outputs.
struct DecoderMaskedMultiHeadAttentionParameters : public AttentionParameters {
  // NOTE(review): AttentionParameters also declares a member named `beam_width`.
  // This declaration hides it, so unqualified access through this type refers to
  // this member (default 1), not to the base-class one.
  int beam_width{1};

  // Only NeoX style rotary embedding is supported
  int t_step{0};

  // Whether to use multihead attention (excludes matmul and bias)
  bool is_mha{false};
  bool is_cross_attention{false};

  // Useful to better use global memory bandwidth on certain CUDA architectures.
  // Off by default for now, until the performance implications for all types of
  // workloads are fully understood.
  // Can be turned on by the appropriate environment variable (see attention_common.h).
  bool kv_data_in_flight{false};

  void* q{nullptr};
  void* q_bias{nullptr};

  void* k{nullptr};
  void* k_bias{nullptr};

  void* v{nullptr};
  void* v_bias{nullptr};

  void* attention_bias{nullptr};

  void* k_cache{nullptr};
  void* v_cache{nullptr};

  void* out{nullptr};
  void* out_qk{nullptr};

  const int32_t* cache_indir{nullptr};
  const int32_t* mask{nullptr};  // [B, total_sequence_length]
};

// Parameters deduced from node attributes and inputs/outputs.
//
// All members are given zero / false / nullptr defaults so that a default-initialized
// instance holds no indeterminate values (in particular no wild `zero_ptr`). These are
// the same values that value-initialization (`{}`) yields.
struct GroupQueryAttentionParameters : AttentionParameters {
  int kv_num_heads = 0;              // number of heads of key or value
  int kv_hidden_size = 0;            // hidden size of key or value
  int seqlen_past_kv_cache = 0;      // sequence length of past kv tensor
  int seqlen_present_kv_cache = 0;   // sequence length of present kv tensor
  int local_window_size = 0;         // Mask out tokens prior to total_sequence_length - local_window_size
  bool is_subsequent_prompt = false;  // indicates whether we have past context and seqlen > 1
  bool is_first_prompt = false;       // indicates whether this is first decoding step
  bool rotary_interleaved = false;
  bool use_smooth_softmax = false;
  float softcap = 0.0f;
  AttentionQkvFormat past_kv_format{};  // value-initialized: underlying value 0
  int zeros_count = 0;
  int* zero_ptr = nullptr;
};

// Parameters deduced from node attributes and inputs/outputs.
//
// All members default to zero / false so that a default-initialized instance holds no
// indeterminate values; these match what value-initialization (`{}`) yields.
struct PagedAttentionParameters : AttentionParameters {
  int kv_num_heads = 0;            // number of heads of key or value
  int kv_hidden_size = 0;          // hidden size of key or value
  int token_count = 0;             // number of tokens in packed query
  int block_size = 0;              // block size for kv cache
  int max_num_blocks_per_seq = 0;  // max number of blocks per sequence for kv cache
  int num_blocks = 0;              // number of blocks in kv cache
  int local_window_size = 0;       // The window size includes new token. It only includes tokens on the left side.
  bool rotary_interleaved = false;
  float softcap = 0.0f;
};

// Parameters for sparse attention.
//
// All members default to zero / false so that a default-initialized instance holds no
// indeterminate values; these match what value-initialization (`{}`) yields.
struct SparseAttentionParameters : AttentionParameters {
  int kv_hidden_size = 0;  // hidden size of key or value
  int kv_num_heads = 0;    // number of heads of key or value
  // Whether to use rotary embedding.
  // NOTE(review): this hides AttentionParameters::do_rotary (which defaults to false).
  // It is initialized to false as well so that neither copy is indeterminate; the member
  // is kept to preserve the existing layout and unqualified-access semantics.
  bool do_rotary = false;
  bool rotary_interleaved = false;     // whether to use interleaved rotary embedding
  int sparse_block_size = 0;           // block size for sparse attention
  int num_sparse_layout = 0;           // number of sparse layout
  int stride_col_indices = 0;          // shape of block_col_indices is [num_sparse_layout, stride_col_indices]
  int stride_row_indices = 0;          // shape of block_row_indices is [num_sparse_layout, stride_row_indices]
  int max_rotary_sequence_length = 0;  // max sequence length for rotary cos/sin cache
  int max_cache_sequence_length = 0;   // max sequence length for kv cache buffer
};

}  // namespace contrib
}  // namespace onnxruntime
