Skip to content

Instantly share code, notes, and snippets.

@DiracSpace
Created September 2, 2025 20:08
Show Gist options
  • Save DiracSpace/5777fe2362b0266a76137500bd75db07 to your computer and use it in GitHub Desktop.
Save DiracSpace/5777fe2362b0266a76137500bd75db07 to your computer and use it in GitHub Desktop.
Custom implementations for running SQL commands and for building/testing connections with a result pattern.
using Microsoft.Data.SqlClient;
using Microsoft.Extensions.Logging;
using System.Text.RegularExpressions;
/// <summary>
/// Aggregate outcome of running a (possibly multi-command) SQL script.
/// </summary>
/// <param name="Status">Overall outcome of the run.</param>
/// <param name="CommandsExecuted">Number of commands the producer of this result reports as executed.</param>
/// <param name="RowsAffected">
/// Per-command results, or <see langword="null"/> when none were recorded
/// (<see cref="Completed(int)"/> always stores <see langword="null"/>).
/// </param>
internal sealed record SqlExecutionResult(
    SqlExecutionResult.CommandState Status,
    int CommandsExecuted,
    IReadOnlyList<SqlRowExecutionResult>? RowsAffected)
{
    public enum CommandState
    {
        Completed,
        CompletedWithErrors,
        CompleteFailure
    }

    /// <summary>
    /// <see langword="true"/> when <see cref="Status"/> is <see cref="CommandState.CompleteFailure"/>
    /// or <see cref="CommandState.CompletedWithErrors"/>.
    /// </summary>
    internal bool HasFailed => Status is CommandState.CompleteFailure or CommandState.CompletedWithErrors;

    /// <summary>
    /// The entries of <see cref="RowsAffected"/> whose <c>IsSuccessful</c> is <see langword="false"/>.
    /// A new list is built on every access; it is empty when <see cref="RowsAffected"/> is null.
    /// </summary>
    internal List<SqlRowExecutionResult> FailedRows =>
        RowsAffected?.Where(r => !r.IsSuccessful).ToList() ?? new List<SqlRowExecutionResult>();

    // string.Join over an empty sequence already returns string.Empty, so no separate
    // "are there any failures?" test is needed before joining.
    /// <summary>
    /// The <c>CommandAndDetails</c> text of every failed row, joined with <c>"; "</c>;
    /// an empty string when there are no failed rows.
    /// </summary>
    internal string FailureMessages => string.Join("; ", FailedRows.Select(r => r.CommandAndDetails));

    /// <summary>
    /// Builds a <see cref="CommandState.CompleteFailure"/> result for a single failed command
    /// (<see cref="CommandsExecuted"/> = 1).
    /// </summary>
    /// <param name="command">Text of the command that failed.</param>
    /// <param name="details">Error details for the failure.</param>
    internal static SqlExecutionResult OneFailed(string command, string details)
    {
        List<SqlRowExecutionResult> row = new()
        {
            SqlRowExecutionResult.Failure(command, details)
        };
        return CompleteFailure(1, row);
    }

    /// <summary>Builds a <see cref="CommandState.Completed"/> result with no per-command rows.</summary>
    internal static SqlExecutionResult Completed(int commands)
        => new(CommandState.Completed, commands, RowsAffected: null);

    /// <summary>Builds a <see cref="CommandState.CompletedWithErrors"/> result carrying the per-command rows.</summary>
    internal static SqlExecutionResult CompletedWithErrors(int commands, List<SqlRowExecutionResult> rowsAffected)
        => new(CommandState.CompletedWithErrors, commands, rowsAffected);

    /// <summary>Builds a <see cref="CommandState.CompleteFailure"/> result carrying the per-command rows.</summary>
    internal static SqlExecutionResult CompleteFailure(int commands, IReadOnlyList<SqlRowExecutionResult> rowsAffected)
        => new(CommandState.CompleteFailure, commands, rowsAffected);
}
/// <summary>
/// Outcome of executing a single command (batch).
/// </summary>
/// <param name="IsSuccessful">Whether the command ran without error.</param>
/// <param name="Command">The command text; null for successful rows built by <see cref="Success"/>.</param>
/// <param name="Details">Error details; null for successful rows built by <see cref="Success"/>.</param>
internal sealed record SqlRowExecutionResult(
    bool IsSuccessful,
    string? Command,
    string? Details)
{
    /// <summary>
    /// Human-readable summary in the form <c>"Command: {Command} Details: {Details}"</c>.
    /// A part whose value is null, empty or whitespace is omitted; the result is an
    /// empty string when both are.
    /// </summary>
    internal string CommandAndDetails
    {
        get
        {
            var parts = new List<string>(2);
            if (!string.IsNullOrWhiteSpace(Command))
            {
                parts.Add($"Command: {Command}");
            }
            if (!string.IsNullOrWhiteSpace(Details))
            {
                parts.Add($"Details: {Details}");
            }
            // Joining the present parts avoids a leading space when only Details is set.
            return string.Join(" ", parts);
        }
    }

    /// <summary>Creates a successful row with no command text or details.</summary>
    internal static SqlRowExecutionResult Success()
        => new(IsSuccessful: true, Command: null, Details: null);

    /// <summary>Creates a failed row recording the command text and the error details.</summary>
    internal static SqlRowExecutionResult Failure(string command, string details)
        => new(IsSuccessful: false, Command: command, Details: details);
}
/// <summary>
/// Runs SQL scripts against a database and returns a <see cref="SqlExecutionResult"/>
/// describing how the run went.
/// </summary>
internal interface ISqlCommandProvider
{
    /// <summary>
    /// Applies <paramref name="script"/> to the database reached through
    /// <paramref name="connectionString"/>.
    /// </summary>
    /// <param name="connectionString">Connection string of the target database.</param>
    /// <param name="script">SQL script to execute.</param>
    /// <param name="cancellationToken">Token used to cancel the operation.</param>
    /// <returns>The aggregated outcome of the script.</returns>
    Task<SqlExecutionResult> TryApplySqlCommandAsync(
        string connectionString,
        string script,
        CancellationToken cancellationToken = default);
}
// TODO: update to reuse the same connection
/// <summary>
/// Default <see cref="ISqlCommandProvider"/>. Splits a script into batches on <c>GO</c>
/// separator lines and runs each non-blank batch with <c>ExecuteNonQueryAsync</c> over one
/// connection opened per call, collecting one <see cref="SqlRowExecutionResult"/> per batch.
/// </summary>
internal sealed class SqlCommandProvider : ISqlCommandProvider
{
    // A line holding only "GO" (case-insensitive, optional surrounding whitespace) separates
    // batches. Built and compiled once instead of on every call.
    private static readonly Regex BatchSeparator = new(
        @"^\s*GO\s*$",
        RegexOptions.Multiline | RegexOptions.IgnoreCase | RegexOptions.Compiled);

    private readonly ILogger<ISqlCommandProvider> _logger;

    public SqlCommandProvider(ILogger<ISqlCommandProvider> logger)
    {
        _logger = logger;
    }

    /// <summary>
    /// Executes every non-blank batch of <paramref name="script"/> against the database
    /// identified by <paramref name="connectionString"/>.
    /// </summary>
    /// <param name="connectionString">SQL Server connection string.</param>
    /// <param name="script">SQL script; batches are separated by lines containing only <c>GO</c>.</param>
    /// <param name="cancellationToken">Checked before starting and before each batch; also passed to SqlClient.</param>
    /// <returns>
    /// <see cref="SqlExecutionResult.CommandState.Completed"/> when every batch succeeded;
    /// <see cref="SqlExecutionResult.CommandState.CompletedWithErrors"/> (with all per-batch rows)
    /// when at least one batch raised a <see cref="SqlException"/>;
    /// <see cref="SqlExecutionResult.CommandState.CompleteFailure"/> when the script contains no
    /// non-blank batch, or when opening the connection / running the script failed
    /// (with the rows collected so far). <c>CommandsExecuted</c> counts the batches sent to the server.
    /// </returns>
    /// <exception cref="ArgumentNullException">An argument is null.</exception>
    /// <exception cref="OperationCanceledException">Cancellation was requested.</exception>
    public async Task<SqlExecutionResult> TryApplySqlCommandAsync(string connectionString, string script,
        CancellationToken cancellationToken = default)
    {
        ArgumentNullException.ThrowIfNull(connectionString);
        ArgumentNullException.ThrowIfNull(script);
        cancellationToken.ThrowIfCancellationRequested();

        // Regex.Split always yields at least one (possibly blank) element, so "no commands" can
        // only be detected after discarding blank batches. Checking before opening the
        // connection also avoids a pointless round-trip.
        List<string> batches = SplitIntoBatches(script);
        if (batches.Count == 0)
        {
            _logger.LogWarning(
                "Could not obtain any {Entity} from script of length {ScriptLength}",
                nameof(SqlCommand),
                script.Length);
            return SqlExecutionResult.CompleteFailure(0, new List<SqlRowExecutionResult>());
        }

        var results = new List<SqlRowExecutionResult>(batches.Count);
        int commandsExecuted = 0;
        try
        {
            _logger.LogInformation("Opening connection");
            await using SqlConnection sqlConnection = new(connectionString);
            await sqlConnection.OpenAsync(cancellationToken);

            // Log the batch count instead of the script text: scripts can be large and may
            // contain sensitive data.
            _logger.LogInformation("Executing script with {BatchCount} batch(es)", batches.Count);

            foreach (string batch in batches)
            {
                cancellationToken.ThrowIfCancellationRequested();
                _logger.LogDebug("Executing command string.");
                await using var command = new SqlCommand(batch, sqlConnection);
                results.Add(await TryRunRowAsync(command, cancellationToken));
                // Counts every batch sent to the server, whether it succeeded or failed.
                commandsExecuted++;
            }

            _logger.LogDebug("Changes applied to database.");
            bool hasFailures = results.Any(r => !r.IsSuccessful);
            string message = hasFailures
                ? "There are failures in the contents."
                : "The contents ran successfully.";
            _logger.LogDebug("{Message}", message);

            return hasFailures
                ? SqlExecutionResult.CompletedWithErrors(commandsExecuted, results)
                : SqlExecutionResult.Completed(commandsExecuted);
        }
        catch (SqlException sqlException)
        {
            _logger.LogError(
                sqlException,
                "SQL operation failed at {Now}. Reason: {Reason}.",
                DateTimeOffset.UtcNow,
                sqlException.Message);
            return SqlExecutionResult.CompleteFailure(commandsExecuted, results);
        }
        catch (Exception exception) when (exception is not OperationCanceledException)
        {
            // Cancellation is excluded so it propagates to the caller, consistent with the
            // ThrowIfCancellationRequested guard at the top of this method.
            _logger.LogError(
                exception,
                "Unexpected failure at {Now}. Reason: {Reason}.",
                DateTimeOffset.UtcNow,
                exception.Message);
            return SqlExecutionResult.CompleteFailure(commandsExecuted, results);
        }
    }

    /// <summary>
    /// Splits <paramref name="script"/> on GO separator lines and drops batches that are
    /// empty or whitespace-only.
    /// </summary>
    private static List<string> SplitIntoBatches(string script) =>
        BatchSeparator.Split(script)
            .Where(batch => !string.IsNullOrWhiteSpace(batch))
            .ToList();

    /// <summary>
    /// Executes one batch. A <see cref="SqlException"/> is logged and converted into a failed
    /// row carrying the batch text and the exception message; other exceptions propagate.
    /// </summary>
    private async Task<SqlRowExecutionResult> TryRunRowAsync(SqlCommand command, CancellationToken cancellationToken)
    {
        try
        {
            await command.ExecuteNonQueryAsync(cancellationToken);
            return SqlRowExecutionResult.Success();
        }
        catch (SqlException sqlException)
        {
            _logger.LogError(
                sqlException,
                "Failed at {Now}. Reason: {Reason}.",
                DateTimeOffset.UtcNow,
                sqlException.Message);
            return SqlRowExecutionResult.Failure(command.CommandText, sqlException.Message);
        }
    }
}
using Microsoft.Data.SqlClient;
using Microsoft.Extensions.Logging;
using System.Data.Common;
/// <summary>
/// Result of a connection attempt: a <see cref="ConnectState"/> plus optional failure details.
/// </summary>
/// <param name="Status">Whether the attempt connected or failed.</param>
/// <param name="Details">Failure description; null for successful attempts.</param>
internal sealed record SqlConnectionResult(
    SqlConnectionResult.ConnectState Status,
    string? Details)
{
    public enum ConnectState
    {
        Connected,
        Failure
    }

    /// <summary><see langword="true"/> exactly when <see cref="Status"/> is <see cref="ConnectState.Failure"/>.</summary>
    public bool HasFailed => Status is ConnectState.Failure;

    /// <summary>Creates a successful result with no details.</summary>
    public static SqlConnectionResult Connected()
    {
        return new SqlConnectionResult(Status: ConnectState.Connected, Details: null);
    }

    /// <summary>Creates a failed result carrying <paramref name="details"/>.</summary>
    public static SqlConnectionResult Failure(string details)
    {
        return new SqlConnectionResult(Status: ConnectState.Failure, Details: details);
    }
}
/// <summary>
/// Builds, validates and tests SQL Server connection strings.
/// </summary>
internal interface ISqlConnectionProvider
{
    /// <summary>Builds a connection string from individual connection settings.</summary>
    /// <param name="source">Server / data source.</param>
    /// <param name="user">Login user name.</param>
    /// <param name="password">Login password.</param>
    /// <param name="database">Optional initial database.</param>
    /// <param name="timeout">Optional connect timeout.</param>
    /// <param name="encrypt">Whether to request encryption.</param>
    /// <param name="trustServer">Whether to trust the server certificate.</param>
    /// <returns>The assembled connection string.</returns>
    string GetConnectionString(
        string source,
        string user,
        string password,
        string? database = null,
        int? timeout = 15,
        bool encrypt = true,
        bool trustServer = true);

    /// <summary>Checks whether <paramref name="sqlConnection"/> can be parsed as a connection string.</summary>
    /// <param name="sqlConnection">Connection string to check.</param>
    /// <returns><see langword="true"/> when the string is considered valid.</returns>
    bool IsConnectionStringValid(string sqlConnection);

    /// <summary>Attempts to open a connection and reports the outcome.</summary>
    /// <param name="sqlConnection">Connection string to connect with.</param>
    /// <param name="cancellationToken">Token used to cancel the attempt.</param>
    /// <returns>The connection outcome.</returns>
    Task<SqlConnectionResult> TryConnectAsync(
        string sqlConnection,
        CancellationToken cancellationToken = default);
}
/// <summary>
/// Default <see cref="ISqlConnectionProvider"/> built on Microsoft.Data.SqlClient.
/// </summary>
internal sealed class SqlConnectionProvider : ISqlConnectionProvider
{
    private readonly ILogger<ISqlConnectionProvider> _logger;

    public SqlConnectionProvider(ILogger<ISqlConnectionProvider> logger)
    {
        _logger = logger;
    }

    /// <summary>
    /// Builds a SQL Server connection string that authenticates with a user name and password.
    /// </summary>
    /// <param name="source">Value for <c>Data Source</c>.</param>
    /// <param name="user">Value for <c>User ID</c>.</param>
    /// <param name="password">Value for <c>Password</c>.</param>
    /// <param name="database">Value for <c>Initial Catalog</c>; omitted when null or whitespace.</param>
    /// <param name="timeout">Value for <c>Connect Timeout</c>; omitted when null.</param>
    /// <param name="encrypt">Value for <c>Encrypt</c>.</param>
    /// <param name="trustServer">
    /// Value for <c>TrustServerCertificate</c>.
    /// NOTE(review): the default of <see langword="true"/> disables server certificate
    /// validation; consider passing <see langword="false"/> outside development.
    /// </param>
    /// <returns>The connection string produced by <see cref="SqlConnectionStringBuilder"/>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="source"/>, <paramref name="user"/> or <paramref name="password"/> is null.</exception>
    /// <remarks>Any exception raised while building the string is logged and rethrown unchanged.</remarks>
    public string GetConnectionString(string source, string user, string password, string? database = null,
        int? timeout = 15, bool encrypt = true, bool trustServer = true)
    {
        ArgumentNullException.ThrowIfNull(source);
        ArgumentNullException.ThrowIfNull(user);
        ArgumentNullException.ThrowIfNull(password);

        try
        {
            _logger.LogDebug("Building SQL connection string.");
            var builder = new SqlConnectionStringBuilder
            {
                DataSource = source,
                UserID = user,
                Password = password,
                Encrypt = encrypt,
                TrustServerCertificate = trustServer,
            };

            if (timeout is not null)
            {
                _logger.LogDebug("Adding connect timeout to SQL connection builder.");
                builder.ConnectTimeout = timeout.Value;
            }

            if (!string.IsNullOrWhiteSpace(database))
            {
                _logger.LogDebug("Adding database name to SQL connection builder.");
                builder.InitialCatalog = database;
            }

            return builder.ConnectionString;
        }
        catch (Exception exception)
        {
            _logger.LogError(
                exception,
                "Could not generate SQL connection string because: {ExceptionMessage}",
                exception.Message
            );
            throw;
        }
    }

    /// <summary>
    /// Checks whether <paramref name="sqlConnection"/> parses as a SqlClient connection string.
    /// </summary>
    /// <param name="sqlConnection">Connection string to validate.</param>
    /// <returns>
    /// <see langword="false"/> when the string is null/whitespace, malformed, uses a keyword
    /// SqlClient does not support, or has a value SqlClient cannot convert; otherwise <see langword="true"/>.
    /// No connection is attempted.
    /// </returns>
    public bool IsConnectionStringValid(string sqlConnection)
    {
        if (string.IsNullOrWhiteSpace(sqlConnection))
        {
            _logger.LogInformation(
                "No SQL connection string provided"
            );
            return false;
        }

        try
        {
            _logger.LogInformation("Validating SQL connection string");
            // Connection-string parsing reports errors as ArgumentException (not SqlException).
            // SqlConnectionStringBuilder, unlike the generic DbConnectionStringBuilder, also
            // rejects keywords and values that SqlConnection itself would refuse.
            _ = new SqlConnectionStringBuilder(sqlConnection);
        }
        catch (Exception exception) when (exception is ArgumentException or FormatException)
        {
            _logger.LogError(
                exception,
                "SQL connection string is not valid because: {Message}",
                exception.Message
            );
            return false;
        }

        _logger.LogInformation("SQL connection string valid");
        return true;
    }

    /// <summary>
    /// Opens (and immediately disposes) a connection using <paramref name="sqlConnection"/>.
    /// </summary>
    /// <param name="sqlConnection">Connection string to connect with.</param>
    /// <param name="cancellationToken">Token passed to <c>OpenAsync</c>.</param>
    /// <returns>Always <see langword="true"/>; failures surface as exceptions from SqlClient.</returns>
    private async Task<bool> CanConnect(string sqlConnection, CancellationToken cancellationToken = default)
    {
        ArgumentNullException.ThrowIfNull(sqlConnection);
        cancellationToken.ThrowIfCancellationRequested();

        _logger.LogInformation("Trying to connect to SQL database with connection string provided");
        await using var connection = new SqlConnection(sqlConnection);
        await connection.OpenAsync(cancellationToken);
        _logger.LogInformation("Can connect to SQL database");
        return true;
    }

    /// <summary>
    /// Attempts to open a connection and reports the outcome instead of throwing for
    /// connection problems.
    /// </summary>
    /// <param name="sqlConnection">Connection string to connect with.</param>
    /// <param name="cancellationToken">Token used to cancel the attempt.</param>
    /// <returns>
    /// <see cref="SqlConnectionResult.Connected"/> on success; otherwise
    /// <see cref="SqlConnectionResult.Failure(string)"/> with the exception message when SqlClient
    /// raises a <see cref="SqlException"/>, rejects the connection string
    /// (<see cref="ArgumentException"/>), or refuses to open (<see cref="InvalidOperationException"/>,
    /// e.g. an empty connection string).
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="sqlConnection"/> is null.</exception>
    /// <exception cref="OperationCanceledException">Cancellation was requested.</exception>
    public async Task<SqlConnectionResult> TryConnectAsync(string sqlConnection, CancellationToken cancellationToken = default)
    {
        ArgumentNullException.ThrowIfNull(sqlConnection);
        cancellationToken.ThrowIfCancellationRequested();

        try
        {
            _logger.LogInformation("Trying to connect to SQL database");
            bool result = await CanConnect(sqlConnection, cancellationToken);
            _logger.LogInformation(
                "Is connection successful: {Result}",
                result);
            return SqlConnectionResult.Connected();
        }
        catch (Exception exception) when (exception is SqlException or ArgumentException or InvalidOperationException)
        {
            // OperationCanceledException is not in the filter, so cancellation still propagates.
            _logger.LogError(
                exception,
                "Could not connect to SQL database because: {Message}",
                exception.Message);
            return SqlConnectionResult.Failure(exception.Message);
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment