// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/rocm/rocm_kernel.h"

using namespace onnxruntime::rocm;

namespace onnxruntime {
namespace contrib {
namespace rocm {

// BiasGelu: fuses an Add (input + bias) followed by Gelu into a single ROCm kernel.
class BiasGelu final : public RocmKernel {
 public:
  // `explicit` blocks accidental implicit conversion from OpKernelInfo to a kernel.
  // NOTE(review): assumes kernel registration constructs this class via direct
  // initialization (e.g. std::make_unique<BiasGelu>(info)) — confirm in the .cc.
  explicit BiasGelu(const OpKernelInfo& info) : RocmKernel(info) {}

  // Runs the fused Add + Gelu computation; defined out of line.
  Status ComputeInternal(OpKernelContext* context) const override;

 private:
  // Element-type-specific launcher for the fused kernel, templated on T.
  // NOTE(review): presumably selected per input dtype by a type dispatcher in the
  // implementation file — confirm.
  template <typename T>
  struct KernelLaunchDispatcher {
    // stream:     HIP stream the kernel is enqueued on.
    // input_size: presumably the element count of X — confirm with the definition.
    // bias_size:  presumably the element count of B — confirm with the definition.
    // X, B:       read-only input and bias tensors.
    // Y:          output tensor, passed by non-const reference so it can be written.
    void operator()(hipStream_t stream, int64_t input_size, int64_t bias_size, const Tensor& X, const Tensor& B,
                    Tensor& Y) const;
  };
};

}  // namespace rocm
}  // namespace contrib
}  // namespace onnxruntime
