diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index d19255d52..b6545e072 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -41,6 +41,7 @@ jobs: with: dotnet-version: | 6.0.x + 7.0.x 8.0.x 9.0.x dotnet-quality: 'ga' @@ -103,6 +104,7 @@ jobs: with: dotnet-version: | 6.0.x + 7.0.x 8.0.x 9.0.x dotnet-quality: 'ga' @@ -163,6 +165,7 @@ jobs: with: dotnet-version: | 6.0.x + 7.0.x 8.0.x 9.0.x dotnet-quality: 'ga' diff --git a/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs b/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs index e3303bdee..81489877e 100644 --- a/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs @@ -14,6 +14,7 @@ using Snowflake.Data.Client; using Snowflake.Data.Core; using Snowflake.Data.Core.Session; +using Snowflake.Data.Core.Tools; using Snowflake.Data.Log; using Snowflake.Data.Tests.Mock; using Snowflake.Data.Tests.Util; @@ -2271,6 +2272,52 @@ public void TestUseMultiplePoolsConnectionPoolByDefault() Assert.AreEqual(ConnectionPoolType.MultipleConnectionPool, poolVersion); } + [Test] + [Ignore("This test requires manual interaction and therefore cannot be run in CI")] // to enroll to mfa authentication edit your user profile + public void TestMFATokenCachingWithPasscodeFromConnectionString() + { + // Use a connection with MFA enabled and set passcode property for mfa authentication. e.g. ConnectionString + ";authenticator=username_password_mfa;passcode=(set proper passcode)" + // ACCOUNT PARAMETER ALLOW_CLIENT_MFA_CACHING should be set to true in the account. + // On Mac/Linux OS the default credential manager is a file based one. Uncomment the following line to test in memory implementation. + // SnowflakeCredentialManagerFactory.UseInMemoryCredentialManager(); + using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString + = ConnectionString + + ";authenticator=username_password_mfa;application=DuoTest;minPoolSize=0;passcode=(set proper passcode)"; + + + // Authenticate to retrieve and store the token if doesn't exist or invalid + Task connectTask = conn.OpenAsync(CancellationToken.None); + connectTask.Wait(); + Assert.AreEqual(ConnectionState.Open, conn.State); + } + } + + [Test] + [Ignore("Requires manual steps and environment with mfa authentication enrolled")] // to enroll to mfa authentication edit your user profile + public void TestMfaWithPasswordConnectionUsingPasscodeWithSecureString() + { + // Use a connection with MFA enabled and Passcode property on connection instance. + // ACCOUNT PARAMETER ALLOW_CLIENT_MFA_CACHING should be set to true in the account. + // On Mac/Linux OS the default credential manager is a file based one. Uncomment the following line to test in memory implementation. + // SnowflakeCredentialManagerFactory.UseInMemoryCredentialManager(); + // arrange + using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) + { + conn.Passcode = SecureStringHelper.Encode("$(set proper passcode)"); + // manual action: stop here in breakpoint to provide proper passcode by: conn.Passcode = SecureStringHelper.Encode("..."); + conn.ConnectionString = ConnectionString + "minPoolSize=2;application=DuoTest;"; + + // act + Task connectTask = conn.OpenAsync(CancellationToken.None); + connectTask.Wait(); + + // assert + Assert.AreEqual(ConnectionState.Open, conn.State); + } + } + [Test] [TestCase("connection_timeout=5;")] [TestCase("")] diff --git a/Snowflake.Data.Tests/Mock/MockLoginMFATokenCacheRestRequester.cs b/Snowflake.Data.Tests/Mock/MockLoginMFATokenCacheRestRequester.cs new file mode 100644 index 000000000..163124b7d --- /dev/null +++ b/Snowflake.Data.Tests/Mock/MockLoginMFATokenCacheRestRequester.cs @@ -0,0 +1,89 @@ +using System; +using System.Collections.Generic; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Snowflake.Data.Core; + +namespace Snowflake.Data.Tests.Mock +{ + using Microsoft.IdentityModel.Tokens; + + class MockLoginMFATokenCacheRestRequester: IMockRestRequester + { + internal Queue LoginRequests { get; } = new(); + + internal Queue LoginResponses { get; } = new(); + + public T Get(IRestRequest request) + { + return Task.Run(async () => await (GetAsync(request, CancellationToken.None)).ConfigureAwait(false)).Result; + } + + public Task GetAsync(IRestRequest request, CancellationToken cancellationToken) + { + return Task.FromResult((T)(object)null); + } + + public Task GetAsync(IRestRequest request, CancellationToken cancellationToken) + { + return Task.FromResult(null); + } + + public HttpResponseMessage Get(IRestRequest request) + { + return null; + } + + public T Post(IRestRequest postRequest) + { + return Task.Run(async () => await (PostAsync(postRequest, CancellationToken.None)).ConfigureAwait(false)).Result; + } + + public Task PostAsync(IRestRequest postRequest, CancellationToken cancellationToken) + { + SFRestRequest sfRequest = (SFRestRequest)postRequest; + if (sfRequest.jsonBody is LoginRequest) + { + LoginRequests.Enqueue((LoginRequest) sfRequest.jsonBody); + var responseData = this.LoginResponses.IsNullOrEmpty() ? new LoginResponseData() + { + token = "session_token", + masterToken = "master_token", + authResponseSessionInfo = new SessionInfo(), + nameValueParameter = new List() + } : this.LoginResponses.Dequeue(); + var authnResponse = new LoginResponse + { + data = responseData, + success = true + }; + + // login request return success + return Task.FromResult((T)(object)authnResponse); + } + else if (sfRequest.jsonBody is CloseResponse) + { + var authnResponse = new CloseResponse() + { + success = true + }; + + // login request return success + return Task.FromResult((T)(object)authnResponse); + } + throw new NotImplementedException(); + } + + public void setHttpClient(HttpClient httpClient) + { + // Nothing to do + } + + public void Reset() + { + LoginRequests.Clear(); + LoginResponses.Clear(); + } + } +} diff --git a/Snowflake.Data.Tests/Mock/MockSnowflakeDbConnection.cs b/Snowflake.Data.Tests/Mock/MockSnowflakeDbConnection.cs index c6d8f0698..2f7d0efc0 100644 --- a/Snowflake.Data.Tests/Mock/MockSnowflakeDbConnection.cs +++ b/Snowflake.Data.Tests/Mock/MockSnowflakeDbConnection.cs @@ -78,10 +78,10 @@ public override Task OpenAsync(CancellationToken cancellationToken) cancellationToken); } - + private void SetMockSession() { - SfSession = new SFSession(ConnectionString, Password, _restRequester); + SfSession = new SFSession(ConnectionString, Password, Passcode, EasyLoggingStarter.Instance, _restRequester); _connectionTimeout = (int)SfSession.connectionTimeout.TotalSeconds; @@ -92,7 +92,7 @@ private void OnSessionEstablished() { _connectionState = ConnectionState.Open; } - + protected override bool CanReuseSession(TransactionRollbackStatus transactionRollbackStatus) { return false; diff --git a/Snowflake.Data.Tests/UnitTests/ArrowResultSetTest.cs b/Snowflake.Data.Tests/UnitTests/ArrowResultSetTest.cs index 0405c7009..8c385ad95 100755 --- a/Snowflake.Data.Tests/UnitTests/ArrowResultSetTest.cs +++ b/Snowflake.Data.Tests/UnitTests/ArrowResultSetTest.cs @@ -33,7 +33,7 @@ public void BeforeTest() // by default generate Int32 values from 1 to RowCount PrepareTestCase(SFDataType.FIXED, 0, Enumerable.Range(1, RowCount).ToArray()); } - + [Test] public void TestResultFormatIsArrow() { @@ -140,7 +140,7 @@ public void TestGetValueReturnsNull() var arrowResultSet = new ArrowResultSet(responseData, sfStatement, new CancellationToken()); arrowResultSet.Next(); - + Assert.AreEqual(true, arrowResultSet.IsDBNull(0)); Assert.AreEqual(DBNull.Value, arrowResultSet.GetValue(0)); } @@ -152,7 +152,7 @@ public void TestGetDecimal() TestGetNumber(testValues); } - + [Test] public void TestGetNumber64() { @@ -165,7 +165,7 @@ public void TestGetNumber64() public void TestGetNumber32() { var testValues = new int[] { 0, 100, -100, Int32.MaxValue, Int32.MinValue }; - + TestGetNumber(testValues); } @@ -176,7 +176,7 @@ public void TestGetNumber16() TestGetNumber(testValues); } - + [Test] public void TestGetNumber8() { @@ -200,7 +200,7 @@ private void TestGetNumber(IEnumerable testValues) Assert.AreEqual(expectedValue, _arrowResultSet.GetDecimal(ColumnIndex)); Assert.AreEqual(expectedValue, _arrowResultSet.GetDouble(ColumnIndex)); Assert.AreEqual(expectedValue, _arrowResultSet.GetFloat(ColumnIndex)); - + if (expectedValue >= Int64.MinValue && expectedValue <= Int64.MaxValue) { // get integer value @@ -230,7 +230,7 @@ public void TestGetBoolean() var testValues = new bool[] { true, false }; PrepareTestCase(SFDataType.BOOLEAN, 0, testValues); - + foreach (var testValue in testValues) { _arrowResultSet.Next(); @@ -245,7 +245,7 @@ public void TestGetReal() var testValues = new double[] { 0, Double.MinValue, Double.MaxValue }; PrepareTestCase(SFDataType.REAL, 0, testValues); - + foreach (var testValue in testValues) { _arrowResultSet.Next(); @@ -253,7 +253,7 @@ public void TestGetReal() Assert.AreEqual(testValue, _arrowResultSet.GetDouble(ColumnIndex)); } } - + [Test] public void TestGetText() { @@ -264,7 +264,7 @@ public void TestGetText() }; PrepareTestCase(SFDataType.TEXT, 0, testValues); - + foreach (var testValue in testValues) { _arrowResultSet.Next(); @@ -272,7 +272,7 @@ public void TestGetText() Assert.AreEqual(testValue, _arrowResultSet.GetString(ColumnIndex)); } } - + [Test] public void TestGetTextWithOneChar() { @@ -290,14 +290,14 @@ public void TestGetTextWithOneChar() #endif PrepareTestCase(SFDataType.TEXT, 0, testValues); - + foreach (var testValue in testValues) { _arrowResultSet.Next(); Assert.AreEqual(testValue, _arrowResultSet.GetChar(ColumnIndex)); } } - + [Test] public void TestGetArray() { @@ -308,7 +308,7 @@ public void TestGetArray() }; PrepareTestCase(SFDataType.ARRAY, 0, testValues); - + foreach (var testValue in testValues) { _arrowResultSet.Next(); @@ -320,7 +320,7 @@ public void TestGetArray() Assert.AreEqual(testValue.Length, str.Length); } } - + [Test] public void TestGetBinary() { @@ -342,7 +342,7 @@ public void TestGetBinary() Assert.AreEqual(testValue[j], buffer[j], "position " + j); } } - + [Test] public void TestGetDate() { @@ -354,7 +354,7 @@ public void TestGetDate() }; PrepareTestCase(SFDataType.DATE, 0, testValues); - + foreach (var testValue in testValues) { _arrowResultSet.Next(); @@ -362,7 +362,7 @@ public void TestGetDate() Assert.AreEqual(testValue, _arrowResultSet.GetDateTime(ColumnIndex)); } } - + [Test] public void TestGetTime() { @@ -384,7 +384,7 @@ public void TestGetTime() Assert.AreEqual(testValue, _arrowResultSet.GetValue(ColumnIndex)); Assert.AreEqual(testValue, _arrowResultSet.GetDateTime(ColumnIndex)); } - } + } } [Test] @@ -513,10 +513,10 @@ private QueryExecResponseData PrepareResponseData(RecordBatch recordBatch, SFDat return new QueryExecResponseData { rowType = recordBatch.Schema.FieldsList - .Select(col => + .Select(col => new ExecResponseRowType { - name = col.Name, + name = col.Name, type = sfType.ToString(), scale = scale }).ToList(), @@ -531,7 +531,7 @@ private string ConvertToBase64String(RecordBatch recordBatch) { if (recordBatch == null) return ""; - + using (var stream = new MemoryStream()) { using (var writer = new ArrowStreamWriter(stream, recordBatch.Schema)) @@ -542,12 +542,12 @@ private string ConvertToBase64String(RecordBatch recordBatch) return Convert.ToBase64String(stream.ToArray()); } } - + private SFStatement PrepareStatement() { SFSession session = new SFSession("user=user;password=password;account=account;", null); return new SFStatement(session); } - + } } diff --git a/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerMFATest.cs b/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerMFATest.cs new file mode 100644 index 000000000..a739e759e --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerMFATest.cs @@ -0,0 +1,136 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + + + +namespace Snowflake.Data.Tests.UnitTests +{ + using System; + using System.Linq; + using System.Security; + using Mock; + using NUnit.Framework; + using Snowflake.Data.Core; + using Snowflake.Data.Core.Session; + using Snowflake.Data.Client; + using Snowflake.Data.Core.Tools; + using Snowflake.Data.Tests.Util; + + [TestFixture, NonParallelizable] + class ConnectionPoolManagerMFATest + { + private readonly ConnectionPoolManager _connectionPoolManager = new ConnectionPoolManager(); + private const string ConnectionStringMFACache = "db=D1;warehouse=W1;account=A1;user=U1;password=P1;role=R1;minPoolSize=2;passcode=12345;authenticator=username_password_mfa"; + private static PoolConfig s_poolConfig; + private static MockLoginMFATokenCacheRestRequester s_restRequester; + + [OneTimeSetUp] + public static void BeforeAllTests() + { + s_poolConfig = new PoolConfig(); + s_restRequester = new MockLoginMFATokenCacheRestRequester(); + SnowflakeDbConnectionPool.ForceConnectionPoolVersion(ConnectionPoolType.MultipleConnectionPool); + SessionPool.SessionFactory = new MockSessionFactoryMFA(s_restRequester); + } + + [OneTimeTearDown] + public static void AfterAllTests() + { + s_poolConfig.Reset(); + SessionPool.SessionFactory = new SessionFactory(); + } + + [SetUp] + public void BeforeEach() + { + _connectionPoolManager.ClearAllPools(); + s_restRequester.Reset(); + } + + [Test] + public void TestPoolManagerReturnsSessionPoolForGivenConnectionStringUsingMFA() + { + // Arrange + var testToken = "testToken1234"; + s_restRequester.LoginResponses.Enqueue(new LoginResponseData() + { + mfaToken = testToken, + authResponseSessionInfo = new SessionInfo() + }); + s_restRequester.LoginResponses.Enqueue(new LoginResponseData() + { + mfaToken = testToken, + authResponseSessionInfo = new SessionInfo() + }); + // Act + var session = _connectionPoolManager.GetSession(ConnectionStringMFACache, null, null); + + // Assert + Awaiter.WaitUntilConditionOrTimeout(() => s_restRequester.LoginRequests.Count == 2, TimeSpan.FromSeconds(15)); + Assert.AreEqual(2, s_restRequester.LoginRequests.Count); + var loginRequest1 = s_restRequester.LoginRequests.Dequeue(); + Assert.AreEqual(string.Empty, loginRequest1.data.Token); + Assert.AreEqual(testToken, SecureStringHelper.Decode(session._mfaToken)); + Assert.IsTrue(loginRequest1.data.SessionParameters.TryGetValue(SFSessionParameter.CLIENT_REQUEST_MFA_TOKEN, out var value) && (bool)value); + Assert.AreEqual("passcode", loginRequest1.data.extAuthnDuoMethod); + var loginRequest2 = s_restRequester.LoginRequests.Dequeue(); + Assert.AreEqual(testToken, loginRequest2.data.Token); + Assert.IsTrue(loginRequest2.data.SessionParameters.TryGetValue(SFSessionParameter.CLIENT_REQUEST_MFA_TOKEN, out var value1) && (bool)value1); + Assert.AreEqual("passcode", loginRequest2.data.extAuthnDuoMethod); + } + + [Test] + public void TestPoolManagerShouldThrowExceptionIfForcePoolingWithPasscodeNotUsingMFATokenCacheAuthenticator() + { + // Arrange + var connectionString = "db=D1;warehouse=W1;account=A1;user=U1;password=P1;role=R1;minPoolSize=2;passcode=12345;POOLINGENABLED=true"; + // Act and assert + var thrown = Assert.Throws(() =>_connectionPoolManager.GetSession(connectionString, null,null)); + Assert.That(thrown.Message, Does.Contain("Passcode with MinPoolSize feature of connection pool allowed only for username_password_mfa authentication")); + } + + [Test] + public void TestPoolManagerShouldThrowExceptionIfForcePoolingWithPasscodeAsSecureStringNotUsingMFATokenCacheAuthenticator() + { + // Arrange + var connectionString = "db=D1;warehouse=W1;account=A1;user=U1;password=P1;role=R1;minPoolSize=2;POOLINGENABLED=true"; + // Act and assert + var thrown = Assert.Throws(() =>_connectionPoolManager.GetSession(connectionString, null,SecureStringHelper.Encode("12345"))); + Assert.That(thrown.Message, Does.Contain("Passcode with MinPoolSize feature of connection pool allowed only for username_password_mfa authentication")); + } + + [Test] + public void TestPoolManagerShouldNotThrowExceptionIfForcePoolingWithPasscodeNotUsingMFATokenCacheAuthenticator() + { + // Arrange + var connectionString = "db=D1;warehouse=W1;account=A1;user=U1;password=P1;role=R1;minPoolSize=2;passcode=12345;POOLINGENABLED=false"; + // Act and assert + Assert.DoesNotThrow(() =>_connectionPoolManager.GetSession(connectionString, null, null)); + } + + [Test] + public void TestPoolManagerShouldNotThrowExceptionIfMinPoolSizeZeroNotUsingMFATokenCacheAuthenticator() + { + // Arrange + var connectionString = "db=D1;warehouse=W1;account=A1;user=U1;password=P1;role=R1;minPoolSize=0;passcode=12345;POOLINGENABLED=true"; + // Act and assert + Assert.DoesNotThrow(() =>_connectionPoolManager.GetSession(connectionString, null, null)); + } + } + + class MockSessionFactoryMFA : ISessionFactory + { + private readonly IMockRestRequester restRequester; + + public MockSessionFactoryMFA(IMockRestRequester restRequester) + { + this.restRequester = restRequester; + } + + public SFSession NewSession(string connectionString, SecureString password, SecureString passcode) + { + return new SFSession(connectionString, password, passcode, EasyLoggingStarter.Instance, restRequester); + } + } +} diff --git a/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerTest.cs b/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerTest.cs index b53487d60..0293d6571 100644 --- a/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerTest.cs +++ b/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerTest.cs @@ -111,7 +111,7 @@ public void TestDifferentPoolsAreReturnedForDifferentConnectionStrings() public void TestGetSessionWorksForSpecifiedConnectionString() { // Act - var sfSession = _connectionPoolManager.GetSession(ConnectionString1, null); + var sfSession = _connectionPoolManager.GetSession(ConnectionString1, null, null); // Assert Assert.AreEqual(ConnectionString1, sfSession.ConnectionString); @@ -122,7 +122,7 @@ public void TestGetSessionWorksForSpecifiedConnectionString() public async Task TestGetSessionAsyncWorksForSpecifiedConnectionString() { // Act - var sfSession = await _connectionPoolManager.GetSessionAsync(ConnectionString1, null, CancellationToken.None); + var sfSession = await _connectionPoolManager.GetSessionAsync(ConnectionString1, null, null, CancellationToken.None); // Assert Assert.AreEqual(ConnectionString1, sfSession.ConnectionString); @@ -133,7 +133,7 @@ public async Task TestGetSessionAsyncWorksForSpecifiedConnectionString() public void TestCountingOfSessionProvidedByPool() { // Act - _connectionPoolManager.GetSession(ConnectionString1, null); + _connectionPoolManager.GetSession(ConnectionString1, null, null); // Assert var sessionPool = _connectionPoolManager.GetPool(ConnectionString1, null); @@ -144,7 +144,7 @@ public void TestCountingOfSessionProvidedByPool() public void TestCountingOfSessionReturnedBackToPool() { // Arrange - var sfSession = _connectionPoolManager.GetSession(ConnectionString1, null); + var sfSession = _connectionPoolManager.GetSession(ConnectionString1, null, null); // Act _connectionPoolManager.AddSession(sfSession); @@ -285,8 +285,8 @@ public void TestGetMaxPoolSizeOnManagerLevelWhenAllPoolsEqual() public void TestGetCurrentPoolSizeReturnsSumOfPoolSizes() { // Arrange - EnsurePoolSize(ConnectionString1, null, 2); - EnsurePoolSize(ConnectionString2, null, 3); + EnsurePoolSize(ConnectionString1, null, null,2); + EnsurePoolSize(ConnectionString2, null, null, 3); // act var poolSize = _connectionPoolManager.GetCurrentPoolSize(); @@ -300,7 +300,7 @@ public void TestReturnPoolForSecurePassword() { // arrange const string AnotherPassword = "anotherPassword"; - EnsurePoolSize(ConnectionStringWithoutPassword, _password3, 1); + EnsurePoolSize(ConnectionStringWithoutPassword, _password3, null, 1); // act var pool = _connectionPoolManager.GetPool(ConnectionStringWithoutPassword, SecureStringHelper.Encode(AnotherPassword)); // a new pool has been created because the password is different @@ -315,9 +315,9 @@ public void TestReturnDifferentPoolWhenPasswordProvidedInDifferentWay() { // arrange var connectionStringWithPassword = $"{ConnectionStringWithoutPassword}password={SecureStringHelper.Decode(_password3)}"; - EnsurePoolSize(ConnectionStringWithoutPassword, _password3, 2); - EnsurePoolSize(connectionStringWithPassword, null, 5); - EnsurePoolSize(connectionStringWithPassword, _password3, 8); + EnsurePoolSize(ConnectionStringWithoutPassword, _password3, null, 2); + EnsurePoolSize(connectionStringWithPassword, null, null, 5); + EnsurePoolSize(connectionStringWithPassword, _password3, null, 8); // act var pool1 = _connectionPoolManager.GetPool(ConnectionStringWithoutPassword, _password3); @@ -360,13 +360,13 @@ public void TestPoolDoesNotSerializePassword() Assert.IsFalse(serializedPool.Contains(password)); } - private void EnsurePoolSize(string connectionString, SecureString password, int requiredCurrentSize) + private void EnsurePoolSize(string connectionString, SecureString password, SecureString passcode, int requiredCurrentSize) { var sessionPool = _connectionPoolManager.GetPool(connectionString, password); sessionPool.SetMaxPoolSize(requiredCurrentSize); for (var i = 0; i < requiredCurrentSize; i++) { - _connectionPoolManager.GetSession(connectionString, password); + _connectionPoolManager.GetSession(connectionString, password, passcode); } Assert.AreEqual(requiredCurrentSize, sessionPool.GetCurrentPoolSize()); } @@ -374,9 +374,9 @@ private void EnsurePoolSize(string connectionString, SecureString password, int class MockSessionFactory : ISessionFactory { - public SFSession NewSession(string connectionString, SecureString password) + public SFSession NewSession(string connectionString, SecureString password, SecureString passcode) { - var mockSfSession = new Mock(connectionString, password); + var mockSfSession = new Mock(connectionString, password, passcode, EasyLoggingStarter.Instance); mockSfSession.Setup(x => x.Open()).Verifiable(); mockSfSession.Setup(x => x.OpenAsync(default)).Returns(Task.FromResult(this)); mockSfSession.Setup(x => x.IsNotOpen()).Returns(false); diff --git a/Snowflake.Data.Tests/UnitTests/CredentialManager/SFBaseCredentialManagerTest.cs b/Snowflake.Data.Tests/UnitTests/CredentialManager/SFBaseCredentialManagerTest.cs new file mode 100644 index 000000000..da981549a --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/CredentialManager/SFBaseCredentialManagerTest.cs @@ -0,0 +1,93 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +using NUnit.Framework; +using Snowflake.Data.Client; + +namespace Snowflake.Data.Tests.UnitTests.CredentialManager +{ + public abstract class SFBaseCredentialManagerTest + { + protected ISnowflakeCredentialManager _credentialManager; + + [Test] + public void TestSavingAndRemovingCredentials() + { + // arrange + var key = "mockKey"; + var expectedToken = "token"; + + // act + _credentialManager.SaveCredentials(key, expectedToken); + + // assert + Assert.AreEqual(expectedToken, _credentialManager.GetCredentials(key)); + + // act + _credentialManager.RemoveCredentials(key); + + // assert + Assert.IsTrue(string.IsNullOrEmpty(_credentialManager.GetCredentials(key))); + } + + [Test] + public void TestSavingCredentialsForAnExistingKey() + { + // arrange + var key = "mockKey"; + var firstExpectedToken = "mockToken1"; + var secondExpectedToken = "mockToken2"; + + // act + _credentialManager.SaveCredentials(key, firstExpectedToken); + + // assert + Assert.AreEqual(firstExpectedToken, _credentialManager.GetCredentials(key)); + + // act + _credentialManager.SaveCredentials(key, secondExpectedToken); + + // assert + Assert.AreEqual(secondExpectedToken, _credentialManager.GetCredentials(key)); + + // act + _credentialManager.RemoveCredentials(key); + + // assert + Assert.IsTrue(string.IsNullOrEmpty(_credentialManager.GetCredentials(key))); + + } + + [Test] + public void TestRemovingCredentialsForKeyThatDoesNotExist() + { + // arrange + var key = "mockKey"; + + // act + _credentialManager.RemoveCredentials(key); + + // assert + Assert.IsTrue(string.IsNullOrEmpty(_credentialManager.GetCredentials(key))); + } + + [Test] + public void TestGetCredentialsForProperKey() + { + // arrange + var key = "key"; + var anotherKey = "anotherKey"; + var token = "token"; + var anotherToken = "anotherToken"; + _credentialManager.SaveCredentials(key, token); + _credentialManager.SaveCredentials(anotherKey, anotherToken); + + // act + var result = _credentialManager.GetCredentials(key); + + // assert + Assert.AreEqual(token, result); + } + } +} diff --git a/Snowflake.Data.Tests/UnitTests/CredentialManager/SFCredentialManagerFileImplTest.cs b/Snowflake.Data.Tests/UnitTests/CredentialManager/SFCredentialManagerFileImplTest.cs new file mode 100644 index 000000000..079c02901 --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/CredentialManager/SFCredentialManagerFileImplTest.cs @@ -0,0 +1,319 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +using System; +using System.IO; +using System.Security; +using Mono.Unix; +using Mono.Unix.Native; +using Moq; +using NUnit.Framework; +using Snowflake.Data.Core.CredentialManager.Infrastructure; +using Snowflake.Data.Core.Tools; + +namespace Snowflake.Data.Tests.UnitTests.CredentialManager +{ + [TestFixture, NonParallelizable] + [Platform(Exclude = "Win")] + public class SFCredentialManagerFileImplTest : SFBaseCredentialManagerTest + { + [ThreadStatic] + private static Mock t_fileOperations; + + [ThreadStatic] + private static Mock t_directoryOperations; + + [ThreadStatic] + private static Mock t_unixOperations; + + [ThreadStatic] + private static Mock t_environmentOperations; + + private const string CustomJsonDir = "testdirectory"; + + private static readonly string s_customJsonPath = Path.Combine(CustomJsonDir, SFCredentialManagerFileStorage.CredentialCacheFileName); + + private static readonly string s_customLockPath = Path.Combine(CustomJsonDir, SFCredentialManagerFileStorage.CredentialCacheLockName); + + private const int UserId = 1; + + [SetUp] + public void SetUp() + { + t_fileOperations = new Mock(); + t_directoryOperations = new Mock(); + t_unixOperations = new Mock(); + t_environmentOperations = new Mock(); + _credentialManager = SFCredentialManagerFileImpl.Instance; + } + + [TearDown] + public void CleanAll() + { + if (SFCredentialManagerFileImpl.Instance._fileStorage != null) + { + File.Delete(SFCredentialManagerFileImpl.Instance._fileStorage.JsonCacheFilePath); + } + } + + [Test] + public void TestThatThrowsErrorWhenCacheFailToCreateCacheFile() + { + // arrange + t_directoryOperations + .Setup(d => d.Exists(s_customJsonPath)) + .Returns(false); + t_unixOperations + .Setup(u => u.CreateFileWithPermissions(s_customJsonPath, + FilePermissions.S_IRUSR | FilePermissions.S_IWUSR)) + .Returns(-1); + t_environmentOperations + .Setup(e => e.GetEnvironmentVariable(SFCredentialManagerFileStorage.CredentialCacheDirectoryEnvironmentName)) + .Returns(CustomJsonDir); + t_directoryOperations + .Setup(d => d.GetParentDirectoryInfo(CustomJsonDir)) + .Returns(new DirectoryInformation(true, DateTime.UtcNow)); + t_unixOperations + .Setup(u => u.GetDirectoryInfo(CustomJsonDir)) + .Returns(new DirectoryUnixInformation(CustomJsonDir, true, FileAccessPermissions.UserReadWriteExecute, UserId)); + t_unixOperations + .Setup(u => u.GetCurrentUserId()) + .Returns(UserId); + t_directoryOperations + .Setup(d => d.GetDirectoryInfo(s_customLockPath)) + .Returns(new DirectoryInformation(false, null)); + _credentialManager = new SFCredentialManagerFileImpl(t_fileOperations.Object, t_directoryOperations.Object, t_unixOperations.Object, t_environmentOperations.Object); + + // act + var thrown = Assert.Throws(() => _credentialManager.SaveCredentials("key", "token")); + + // assert + Assert.That(thrown.Message, Does.Contain("Failed to create the JSON token cache file")); + } + + [Test] + public void TestThatThrowsErrorWhenCacheFileCanBeAccessedByOthers() + { + // arrange + var tempDirectory = Path.Combine(Path.GetTempPath(), Path.GetRandomFileName()); + t_environmentOperations + .Setup(e => e.GetEnvironmentVariable(SFCredentialManagerFileStorage.CredentialCacheDirectoryEnvironmentName)) + .Returns(tempDirectory); + _credentialManager = CreateFileCredentialManagerWithMockedEnvironmentalVariables(); + try + { + DirectoryOperations.Instance.CreateDirectory(tempDirectory); + UnixOperations.Instance.CreateFileWithPermissions(Path.Combine(tempDirectory, SFCredentialManagerFileStorage.CredentialCacheFileName), FilePermissions.ALLPERMS); + + // act + var thrown = Assert.Throws(() => _credentialManager.SaveCredentials("key", "token")); + + // assert + Assert.That(thrown.Message, Does.Contain("Attempting to read or write a file with too broad permissions assigned")); + } + finally + { + DirectoryOperations.Instance.Delete(tempDirectory, true); + } + } + + [Test] + public void TestThatJsonFileIsCheckedIfAlreadyExists() + { + // arrange + t_unixOperations + .Setup(u => u.CreateFileWithPermissions(s_customJsonPath, + FilePermissions.S_IRUSR | FilePermissions.S_IWUSR)) + .Returns(0); + t_unixOperations + .Setup(u => u.GetFilePermissions(s_customJsonPath)) + .Returns(FileAccessPermissions.UserRead | FileAccessPermissions.UserWrite); + t_environmentOperations + .Setup(e => e.GetEnvironmentVariable(SFCredentialManagerFileStorage.CredentialCacheDirectoryEnvironmentName)) + .Returns(CustomJsonDir); + t_fileOperations + .SetupSequence(f => f.Exists(s_customJsonPath)) + .Returns(false) + .Returns(true); + t_directoryOperations + .Setup(d => d.GetParentDirectoryInfo(CustomJsonDir)) + .Returns(new DirectoryInformation(true, DateTime.UtcNow)); + t_unixOperations + .Setup(u => u.GetDirectoryInfo(CustomJsonDir)) + .Returns(new DirectoryUnixInformation(CustomJsonDir, true, FileAccessPermissions.UserReadWriteExecute, UserId)); + t_unixOperations + .Setup(u => u.GetCurrentUserId()) + .Returns(UserId); + t_directoryOperations + .Setup(d => d.GetDirectoryInfo(s_customLockPath)) + .Returns(new DirectoryInformation(false, null)); + _credentialManager = new SFCredentialManagerFileImpl(t_fileOperations.Object, t_directoryOperations.Object, t_unixOperations.Object, t_environmentOperations.Object); + + // act + _credentialManager.SaveCredentials("key", "token"); + + // assert + t_fileOperations.Verify(f => f.Exists(s_customJsonPath), Times.Exactly(2)); + } + + [Test] + public void TestWritingIsUnavailableIfFailedToCreateDirLock() + { + // arrange + t_unixOperations + .Setup(u => u.GetFilePermissions(s_customJsonPath)) + .Returns(FileAccessPermissions.UserRead | FileAccessPermissions.UserWrite); + t_environmentOperations + .Setup(e => e.GetEnvironmentVariable(SFCredentialManagerFileStorage.CredentialCacheDirectoryEnvironmentName)) + .Returns(CustomJsonDir); + t_fileOperations + .SetupSequence(f => f.Exists(s_customJsonPath)) + .Returns(false) + .Returns(true); + t_directoryOperations + .Setup(d => d.GetDirectoryInfo(s_customLockPath)) + .Returns(new DirectoryInformation(false, null)); + t_directoryOperations + .Setup(d => d.GetParentDirectoryInfo(CustomJsonDir)) + .Returns(new DirectoryInformation(true, DateTime.UtcNow)); + t_unixOperations + .Setup(u => u.GetDirectoryInfo(CustomJsonDir)) + .Returns(new DirectoryUnixInformation(CustomJsonDir, true, FileAccessPermissions.UserReadWriteExecute, UserId)); + t_unixOperations + .Setup(u => u.GetCurrentUserId()) + .Returns(UserId); + t_unixOperations + .Setup(u => u.CreateDirectoryWithPermissions(s_customLockPath, SFCredentialManagerFileImpl.CredentialCacheLockDirPermissions)) + .Returns(-1); + _credentialManager = new SFCredentialManagerFileImpl(t_fileOperations.Object, t_directoryOperations.Object, t_unixOperations.Object, t_environmentOperations.Object); + + // act + _credentialManager.SaveCredentials("key", "token"); + + // assert + t_fileOperations.Verify(f => f.Write(s_customJsonPath, It.IsAny(), It.IsAny>()), Times.Never); + } + + [Test] + public void TestReadingIsUnavailableIfFailedToCreateDirLock() + { + // arrange + t_unixOperations + .Setup(u => u.GetFilePermissions(s_customJsonPath)) + .Returns(FileAccessPermissions.UserRead | FileAccessPermissions.UserWrite); + t_environmentOperations + .Setup(e => e.GetEnvironmentVariable(SFCredentialManagerFileStorage.CredentialCacheDirectoryEnvironmentName)) + .Returns(CustomJsonDir); + t_fileOperations + .SetupSequence(f => f.Exists(s_customJsonPath)) + .Returns(false) + .Returns(true); + t_unixOperations + .Setup(u => u.CreateDirectoryWithPermissions(s_customLockPath, SFCredentialManagerFileImpl.CredentialCacheLockDirPermissions)) + .Returns(-1); + t_directoryOperations + .Setup(d => d.GetParentDirectoryInfo(CustomJsonDir)) + .Returns(new DirectoryInformation(true, DateTime.UtcNow)); + t_unixOperations + .Setup(u => u.GetDirectoryInfo(CustomJsonDir)) + .Returns(new DirectoryUnixInformation(CustomJsonDir, true, FileAccessPermissions.UserReadWriteExecute, UserId)); + t_unixOperations + .Setup(u => u.GetCurrentUserId()) + .Returns(UserId); + t_directoryOperations + .Setup(d => d.GetDirectoryInfo(s_customLockPath)) + .Returns(new DirectoryInformation(false, null)); + _credentialManager = new SFCredentialManagerFileImpl(t_fileOperations.Object, t_directoryOperations.Object, t_unixOperations.Object, t_environmentOperations.Object); + + // act + _credentialManager.GetCredentials("key"); + + // assert + t_fileOperations.Verify(f => f.ReadAllText(s_customJsonPath, It.IsAny>()), Times.Never); + } + + [Test] + public void TestReadingAndWritingAreUnavailableIfDirLockExists() + { + // arrange + var tempDirectory = Path.Combine(Path.GetTempPath(), Path.GetRandomFileName()); + t_environmentOperations + .Setup(e => e.GetEnvironmentVariable(SFCredentialManagerFileStorage.CredentialCacheDirectoryEnvironmentName)) + .Returns(tempDirectory); + _credentialManager = CreateFileCredentialManagerWithMockedEnvironmentalVariables(); + try + { + DirectoryOperations.Instance.CreateDirectory(tempDirectory); + DirectoryOperations.Instance.CreateDirectory(Path.Combine(tempDirectory, SFCredentialManagerFileStorage.CredentialCacheLockName)); + + // act + _credentialManager.SaveCredentials("key", "token"); + var result = _credentialManager.GetCredentials("key"); + + // assert + Assert.AreEqual(string.Empty, result); + } + finally + { + DirectoryOperations.Instance.Delete(tempDirectory, true); + } + } + + [Test] + public void TestChangeCacheDirPermissionsWhenInsecure() + { + // arrange + var tempDirectory = Path.Combine(Path.GetTempPath(), Path.GetRandomFileName()); + t_environmentOperations + .Setup(e => e.GetEnvironmentVariable(SFCredentialManagerFileStorage.CredentialCacheDirectoryEnvironmentName)) + .Returns(tempDirectory); + _credentialManager = CreateFileCredentialManagerWithMockedEnvironmentalVariables(); + try + { + DirectoryOperations.Instance.CreateDirectory(tempDirectory); + UnixOperations.Instance.ChangePermissions(tempDirectory, FileAccessPermissions.UserReadWriteExecute | FileAccessPermissions.GroupRead); + + // act + _credentialManager.SaveCredentials("key", "token"); + var result = _credentialManager.GetCredentials("key"); + + // assert + Assert.AreEqual("token", result); + Assert.AreEqual(FileAccessPermissions.UserReadWriteExecute, UnixOperations.Instance.GetDirectoryInfo(tempDirectory).Permissions); + } + finally + { + DirectoryOperations.Instance.Delete(tempDirectory, true); + } + } + + [Test] + public void TestCreateDirectoryWithSecurePermissions() + { + // arrange + var tempDirectory = Path.Combine(Path.GetTempPath(), Path.GetRandomFileName()); + t_environmentOperations + .Setup(e => e.GetEnvironmentVariable(SFCredentialManagerFileStorage.CredentialCacheDirectoryEnvironmentName)) + .Returns(tempDirectory); + _credentialManager = CreateFileCredentialManagerWithMockedEnvironmentalVariables(); + try + { + // act + _credentialManager.SaveCredentials("key", "token"); + var result = _credentialManager.GetCredentials("key"); + + // assert + Assert.AreEqual("token", result); + Assert.AreEqual(FileAccessPermissions.UserReadWriteExecute, UnixOperations.Instance.GetDirectoryInfo(tempDirectory).Permissions); + } + finally + { + DirectoryOperations.Instance.Delete(tempDirectory, true); + } + } + + private SFCredentialManagerFileImpl CreateFileCredentialManagerWithMockedEnvironmentalVariables() => + new (FileOperations.Instance, DirectoryOperations.Instance, UnixOperations.Instance, t_environmentOperations.Object); + } +} diff --git a/Snowflake.Data.Tests/UnitTests/CredentialManager/SFCredentialManagerFileStorageTest.cs b/Snowflake.Data.Tests/UnitTests/CredentialManager/SFCredentialManagerFileStorageTest.cs new file mode 100644 index 000000000..4be5f9513 --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/CredentialManager/SFCredentialManagerFileStorageTest.cs @@ -0,0 +1,114 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +using System; +using System.IO; +using NUnit.Framework; +using Moq; +using Snowflake.Data.Core.CredentialManager.Infrastructure; +using Snowflake.Data.Core.Tools; + + +namespace Snowflake.Data.Tests.UnitTests.CredentialManager +{ + [TestFixture] + public class SFCredentialManagerFileStorageTest + { + private const string SnowflakeCacheLocation = "/Users/snowflake/cache"; + private const string CommonCacheLocation = "/Users/snowflake/.cache"; + private const string HomeLocation = "/Users/snowflake"; + + [ThreadStatic] + private static Mock t_environmentOperations; + + [SetUp] + public void SetUp() + { + t_environmentOperations = new Mock(); + } + + [Test] + public void TestChooseLocationFromSnowflakeCacheEnvironmentVariable() + { + // arrange + MockSnowflakeCacheEnvironmentVariable(); + MockCommonCacheEnvironmentVariable(); + MockHomeLocation(); + + // act + var fileStorage = new SFCredentialManagerFileStorage(t_environmentOperations.Object); + + // assert + AssertFileStorageForLocation(SnowflakeCacheLocation, fileStorage); + } + + [Test] + public void TestChooseLocationFromCommonCacheEnvironmentVariable() + { + // arrange + MockCommonCacheEnvironmentVariable(); + MockHomeLocation(); + var expectedLocation = Path.Combine(CommonCacheLocation, SFCredentialManagerFileStorage.CredentialCacheDirName); + + // act + var fileStorage = new SFCredentialManagerFileStorage(t_environmentOperations.Object); + + // assert + AssertFileStorageForLocation(expectedLocation, fileStorage); + } + + [Test] + public void TestChooseLocationFromHomeFolder() + { + // arrange + MockHomeLocation(); + var expectedLocation = Path.Combine(HomeLocation, SFCredentialManagerFileStorage.CommonCacheDirectoryName, SFCredentialManagerFileStorage.CredentialCacheDirName); + + // act + var fileStorage = new SFCredentialManagerFileStorage(t_environmentOperations.Object); + + // assert + AssertFileStorageForLocation(expectedLocation, fileStorage); + } + + [Test] + public void TestFailWhenLocationCannotBeIdentified() + { + // act + var thrown = Assert.Throws(() => new SFCredentialManagerFileStorage(t_environmentOperations.Object)); + + // assert + Assert.That(thrown.Message, Contains.Substring("Unable to identify credential cache directory")); + } + + private void AssertFileStorageForLocation(string directory, SFCredentialManagerFileStorage fileStorage) + { + Assert.NotNull(fileStorage); + Assert.AreEqual(directory, fileStorage.JsonCacheDirectory); + Assert.AreEqual(Path.Combine(directory, SFCredentialManagerFileStorage.CredentialCacheFileName), fileStorage.JsonCacheFilePath); + Assert.AreEqual(Path.Combine(directory, SFCredentialManagerFileStorage.CredentialCacheLockName), fileStorage.JsonCacheLockPath); + } + + private void MockSnowflakeCacheEnvironmentVariable() + { + t_environmentOperations + .Setup(e => e.GetEnvironmentVariable(SFCredentialManagerFileStorage.CredentialCacheDirectoryEnvironmentName)) + .Returns(SnowflakeCacheLocation); + } + + private void MockCommonCacheEnvironmentVariable() + { + t_environmentOperations + .Setup(e => e.GetEnvironmentVariable(SFCredentialManagerFileStorage.CommonCacheDirectoryEnvironmentName)) + .Returns(CommonCacheLocation); + } + + private void MockHomeLocation() + { + t_environmentOperations + .Setup(e => e.GetFolderPath(Environment.SpecialFolder.UserProfile)) + .Returns(HomeLocation); + } + } +} diff --git a/Snowflake.Data.Tests/UnitTests/CredentialManager/SFCredentialManagerInMemoryImplTest.cs b/Snowflake.Data.Tests/UnitTests/CredentialManager/SFCredentialManagerInMemoryImplTest.cs new file mode 100644 index 000000000..09c9a51bb --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/CredentialManager/SFCredentialManagerInMemoryImplTest.cs @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +using NUnit.Framework; +using Snowflake.Data.Core.CredentialManager.Infrastructure; + +namespace Snowflake.Data.Tests.UnitTests.CredentialManager +{ + [TestFixture, NonParallelizable] + public class SFCredentialManagerInMemoryImplTest : SFBaseCredentialManagerTest + { + [SetUp] + public void SetUp() + { + _credentialManager = SFCredentialManagerInMemoryImpl.Instance; + } + } +} diff --git a/Snowflake.Data.Tests/UnitTests/CredentialManager/SFCredentialManagerWindowsNativeImplTest.cs b/Snowflake.Data.Tests/UnitTests/CredentialManager/SFCredentialManagerWindowsNativeImplTest.cs new file mode 100644 index 000000000..a954a6e5d --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/CredentialManager/SFCredentialManagerWindowsNativeImplTest.cs @@ -0,0 +1,20 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +using NUnit.Framework; +using Snowflake.Data.Core.CredentialManager.Infrastructure; + +namespace Snowflake.Data.Tests.UnitTests.CredentialManager +{ + [TestFixture, NonParallelizable] + [Platform("Win")] + public class SFCredentialManagerWindowsNativeImplTest : SFBaseCredentialManagerTest + { + [SetUp] + public void SetUp() + { + _credentialManager = SFCredentialManagerWindowsNativeImpl.Instance; + } + } +} diff --git a/Snowflake.Data.Tests/UnitTests/CredentialManager/SnowflakeCredentialManagerFactoryTest.cs b/Snowflake.Data.Tests/UnitTests/CredentialManager/SnowflakeCredentialManagerFactoryTest.cs new file mode 100644 index 000000000..498d7fafe --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/CredentialManager/SnowflakeCredentialManagerFactoryTest.cs @@ -0,0 +1,94 @@ +using System; +using System.Runtime.InteropServices; +using NUnit.Framework; +using Snowflake.Data.Client; +using Snowflake.Data.Core.CredentialManager.Infrastructure; + +namespace Snowflake.Data.Tests.UnitTests.CredentialManager +{ + [TestFixture, NonParallelizable] + public class SnowflakeCredentialManagerFactoryTest + { + [TearDown] + public void TearDown() + { + SnowflakeCredentialManagerFactory.UseDefaultCredentialManager(); + } + + [Test] + public void TestUsingDefaultCredentialManager() + { + // arrange + SnowflakeCredentialManagerFactory.UseDefaultCredentialManager(); + + // act + var credentialManager = SnowflakeCredentialManagerFactory.GetCredentialManager(); + + // assert + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + Assert.IsInstanceOf(credentialManager); + } + else + { + Assert.IsInstanceOf(credentialManager); + } + } + + [Test] + public void TestSettingCustomCredentialManager() + { + // arrange + SnowflakeCredentialManagerFactory.SetCredentialManager(SFCredentialManagerInMemoryImpl.Instance); + + // act + var credentialManager = SnowflakeCredentialManagerFactory.GetCredentialManager(); + + // assert + Assert.IsInstanceOf(credentialManager); + } + + [Test] + public void TestUseMemoryImplCredentialManager() + { + // arrange + SnowflakeCredentialManagerFactory.UseInMemoryCredentialManager(); + + // act + var credentialManager = SnowflakeCredentialManagerFactory.GetCredentialManager(); + + // assert + Assert.IsInstanceOf(credentialManager); + } + + [Test] + public void TestThatThrowsErrorWhenTryingToSetCredentialManagerToNull() + { + // act and assert + var exception = Assert.Throws(() => SnowflakeCredentialManagerFactory.SetCredentialManager(null)); + Assert.That(exception.Message, Does.Contain("Credential manager cannot be null. If you want to use the default credential manager, please call the UseDefaultCredentialManager method.")); + } + + [Test] + [Platform(Exclude = "Win")] + public void TestUseWindowsCredentialManagerFailsOnUnix() + { + // act + var thrown = Assert.Throws(SnowflakeCredentialManagerFactory.UseWindowsCredentialManager); + + // assert + Assert.AreEqual("Windows native credential manager implementation can be used only on Windows", thrown.Message); + } + + [Test] + [Platform("Win")] + public void TestUseFileCredentialManagerFailsOnWindows() + { + // act + var thrown = Assert.Throws(SnowflakeCredentialManagerFactory.UseFileCredentialManager); + + // assert + Assert.AreEqual("File credential manager implementation is not supported on Windows", thrown.Message); + } + } +} diff --git a/Snowflake.Data.Tests/UnitTests/SFSessionPropertyTest.cs b/Snowflake.Data.Tests/UnitTests/SFSessionPropertyTest.cs index a57a9fb74..044ac5ddc 100644 --- a/Snowflake.Data.Tests/UnitTests/SFSessionPropertyTest.cs +++ b/Snowflake.Data.Tests/UnitTests/SFSessionPropertyTest.cs @@ -25,7 +25,7 @@ public void TestThatPropertiesAreParsed(TestCase testcase) testcase.SecurePassword); // assert - CollectionAssert.AreEquivalent(testcase.ExpectedProperties, properties); + CollectionAssert.IsSubsetOf(testcase.ExpectedProperties, properties); } [Test] @@ -104,6 +104,76 @@ public void TestFailWhenNoPasswordProvided(string connectionString, string passw Assert.That(exception.Message, Does.Contain("Required property PASSWORD is not provided")); } + [Test] + public void TestParsePasscode() + { + // arrange + var expectedPasscode = "abc"; + var connectionString = $"ACCOUNT=testaccount;USER=testuser;PASSWORD=testpassword;PASSCODE={expectedPasscode}"; + + // act + var properties = SFSessionProperties.ParseConnectionString(connectionString, null); + + // assert + Assert.AreEqual(expectedPasscode, properties[SFSessionProperty.PASSCODE]); + } + + [Test] + public void TestUsePasscodeFromSecureString() + { + // arrange + var expectedPasscode = "abc"; + var connectionString = $"ACCOUNT=testaccount;USER=testuser;PASSWORD=testpassword"; + var securePasscode = SecureStringHelper.Encode(expectedPasscode); + + // act + var properties = SFSessionProperties.ParseConnectionString(connectionString, null, securePasscode); + + // assert + Assert.AreEqual(expectedPasscode, properties[SFSessionProperty.PASSCODE]); + } + + [Test] + [TestCase("ACCOUNT=testaccount;USER=testuser;PASSWORD=testpassword;")] + [TestCase("ACCOUNT=testaccount;USER=testuser;PASSWORD=testpassword;PASSCODE=")] + public void TestDoNotParsePasscodeWhenNotProvided(string connectionString) + { + // act + var properties = SFSessionProperties.ParseConnectionString(connectionString, null); + + // assert + Assert.False(properties.TryGetValue(SFSessionProperty.PASSCODE, out _)); + } + + [Test] + [TestCase("ACCOUNT=testaccount;USER=testuser;PASSWORD=testpassword;", "false")] + [TestCase("ACCOUNT=testaccount;USER=testuser;PASSWORD=testpassword;passcodeInPassword=", "false")] + [TestCase("ACCOUNT=testaccount;USER=testuser;PASSWORD=testpassword;passcodeInPassword=true", "true")] + [TestCase("ACCOUNT=testaccount;USER=testuser;PASSWORD=testpassword;passcodeInPassword=TRUE", "TRUE")] + [TestCase("ACCOUNT=testaccount;USER=testuser;PASSWORD=testpassword;passcodeInPassword=false", "false")] + [TestCase("ACCOUNT=testaccount;USER=testuser;PASSWORD=testpassword;passcodeInPassword=FALSE", "FALSE")] + public void TestParsePasscodeInPassword(string connectionString, string expectedPasscodeInPassword) + { + // act + var properties = SFSessionProperties.ParseConnectionString(connectionString, null); + + // assert + Assert.IsTrue(properties.TryGetValue(SFSessionProperty.PASSCODEINPASSWORD, out var passcodeInPassword)); + Assert.AreEqual(expectedPasscodeInPassword, passcodeInPassword); + } + + [Test] + public void TestFailWhenInvalidPasscodeInPassword() + { + // arrange + var invalidConnectionString = "ACCOUNT=testaccount;USER=testuser;PASSWORD=testpassword;passcodeInPassword=abc"; + + // act + var thrown = Assert.Throws(() => SFSessionProperties.ParseConnectionString(invalidConnectionString, null)); + + Assert.That(thrown.Message, Does.Contain("Invalid parameter value for PASSCODEINPASSWORD")); + } + [Test] [TestCase("DB", SFSessionProperty.DB, "\"testdb\"")] [TestCase("SCHEMA", SFSessionProperty.SCHEMA, "\"quotedSchema\"")] @@ -222,7 +292,8 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) }, - { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) } + { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) }, + { SFSessionProperty.PASSCODEINPASSWORD, DefaultValue(SFSessionProperty.PASSCODEINPASSWORD) } } }; @@ -258,7 +329,8 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) }, - { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) } + { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) }, + { SFSessionProperty.PASSCODEINPASSWORD, DefaultValue(SFSessionProperty.PASSCODEINPASSWORD) } } }; var testCaseWithProxySettings = new TestCase() @@ -296,7 +368,8 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) }, - { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) } + { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) }, + { SFSessionProperty.PASSCODEINPASSWORD, DefaultValue(SFSessionProperty.PASSCODEINPASSWORD) } }, ConnectionString = $"ACCOUNT={defAccount};USER={defUser};PASSWORD={defPassword};useProxy=true;proxyHost=proxy.com;proxyPort=1234;nonProxyHosts=localhost" @@ -336,7 +409,8 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) }, - { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) } + { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) }, + { SFSessionProperty.PASSCODEINPASSWORD, DefaultValue(SFSessionProperty.PASSCODEINPASSWORD) } }, ConnectionString = $"ACCOUNT={defAccount};USER={defUser};PASSWORD={defPassword};proxyHost=proxy.com;proxyPort=1234;nonProxyHosts=localhost" @@ -375,7 +449,8 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) }, - { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) } + { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) }, + { SFSessionProperty.PASSCODEINPASSWORD, DefaultValue(SFSessionProperty.PASSCODEINPASSWORD) } } }; var testCaseWithIncludeRetryReason = new TestCase() @@ -411,7 +486,8 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) }, - { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) } + { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) }, + { SFSessionProperty.PASSCODEINPASSWORD, DefaultValue(SFSessionProperty.PASSCODEINPASSWORD) } } }; var testCaseWithDisableQueryContextCache = new TestCase() @@ -446,7 +522,8 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) }, - { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) } + { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) }, + { SFSessionProperty.PASSCODEINPASSWORD, DefaultValue(SFSessionProperty.PASSCODEINPASSWORD) } }, ConnectionString = $"ACCOUNT={defAccount};USER={defUser};PASSWORD={defPassword};DISABLEQUERYCONTEXTCACHE=true" @@ -483,7 +560,8 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) }, - { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) } + { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) }, + { SFSessionProperty.PASSCODEINPASSWORD, DefaultValue(SFSessionProperty.PASSCODEINPASSWORD) } }, ConnectionString = $"ACCOUNT={defAccount};USER={defUser};PASSWORD={defPassword};DISABLE_CONSOLE_LOGIN=false" @@ -522,7 +600,8 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) }, - { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) } + { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) }, + { SFSessionProperty.PASSCODEINPASSWORD, DefaultValue(SFSessionProperty.PASSCODEINPASSWORD) } } }; var testCaseUnderscoredAccountName = new TestCase() @@ -558,7 +637,8 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) }, - { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) } + { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) }, + { SFSessionProperty.PASSCODEINPASSWORD, DefaultValue(SFSessionProperty.PASSCODEINPASSWORD) } } }; var testCaseUnderscoredAccountNameWithEnabledAllowUnderscores = new TestCase() @@ -594,9 +674,11 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) }, - { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) } + { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) }, + { SFSessionProperty.PASSCODEINPASSWORD, DefaultValue(SFSessionProperty.PASSCODEINPASSWORD) } } }; + var testQueryTag = "Test QUERY_TAG 12345"; var testCaseQueryTag = new TestCase() { @@ -632,7 +714,8 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) }, - { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) } + { SFSessionProperty.DISABLE_SAML_URL_CHECK, DefaultValue(SFSessionProperty.DISABLE_SAML_URL_CHECK) }, + { SFSessionProperty.PASSCODEINPASSWORD, DefaultValue(SFSessionProperty.PASSCODEINPASSWORD) } } }; diff --git a/Snowflake.Data.Tests/UnitTests/SFSessionTest.cs b/Snowflake.Data.Tests/UnitTests/SFSessionTest.cs index 262122b2d..3129bb509 100644 --- a/Snowflake.Data.Tests/UnitTests/SFSessionTest.cs +++ b/Snowflake.Data.Tests/UnitTests/SFSessionTest.cs @@ -5,6 +5,7 @@ using Newtonsoft.Json; using Snowflake.Data.Core; using NUnit.Framework; +using Snowflake.Data.Core.Tools; using Snowflake.Data.Tests.Mock; namespace Snowflake.Data.Tests.UnitTests @@ -99,7 +100,7 @@ public void TestThatConfiguresEasyLogging(string configPath) : $"{simpleConnectionString}client_config_file={configPath};"; // act - new SFSession(connectionString, null, easyLoggingStarter.Object); + new SFSession(connectionString, null, null, easyLoggingStarter.Object); // assert easyLoggingStarter.Verify(starter => starter.Init(configPath)); @@ -157,5 +158,165 @@ public void TestHandlePasswordWithQuotations() // assert Assert.AreEqual(loginRequest.data.password, deserializedLoginRequest.data.password); } + + [Test] + public void TestHandlePasscodeParameter() + { + // arrange + var passcode = "123456"; + MockLoginStoringRestRequester restRequester = new MockLoginStoringRestRequester(); + SFSession sfSession = new SFSession($"account=test;user=test;password=test;passcode={passcode}", null, restRequester); + + // act + sfSession.Open(); + + // assert + Assert.AreEqual(1, restRequester.LoginRequests.Count); + var loginRequest = restRequester.LoginRequests[0]; + Assert.AreEqual(passcode, loginRequest.data.passcode); + Assert.AreEqual("passcode", loginRequest.data.extAuthnDuoMethod); + } + + [Test] + public void TestHandlePasscodeAsSecureString() + { + // arrange + var passcode = "123456"; + MockLoginStoringRestRequester restRequester = new MockLoginStoringRestRequester(); + SFSession sfSession = new SFSession($"account=test;user=test;password=test;", null, SecureStringHelper.Encode(passcode), EasyLoggingStarter.Instance, restRequester); + + // act + sfSession.Open(); + + // assert + Assert.AreEqual(1, restRequester.LoginRequests.Count); + var loginRequest = restRequester.LoginRequests[0]; + Assert.AreEqual(passcode, loginRequest.data.passcode); + Assert.AreEqual("passcode", loginRequest.data.extAuthnDuoMethod); + } + + [Test] + public void TestHandlePasscodeInPasswordParameter() + { + // arrange + var passcode = "123456"; + MockLoginStoringRestRequester restRequester = new MockLoginStoringRestRequester(); + SFSession sfSession = new SFSession($"account=test;user=test;password=test{passcode};passcodeInPassword=true;", null, restRequester); + + // act + sfSession.Open(); + + // assert + Assert.AreEqual(1, restRequester.LoginRequests.Count); + var loginRequest = restRequester.LoginRequests[0]; + Assert.IsNull(loginRequest.data.passcode); + Assert.AreEqual("passcode", loginRequest.data.extAuthnDuoMethod); + } + + [Test] + public void TestPushWhenNoPasscodeAndPasscodeInPasswordIsFalse() + { + // arrange + MockLoginStoringRestRequester restRequester = new MockLoginStoringRestRequester(); + SFSession sfSession = new SFSession($"account=test;user=test;password=test;passcodeInPassword=false;", null, restRequester); + + // act + sfSession.Open(); + + // assert + Assert.AreEqual(1, restRequester.LoginRequests.Count); + var loginRequest = restRequester.LoginRequests[0]; + Assert.IsNull(loginRequest.data.passcode); + Assert.AreEqual("push", loginRequest.data.extAuthnDuoMethod); + } + + [Test] + public void TestPushAsDefaultSecondaryAuthentication() + { + // arrange + MockLoginStoringRestRequester restRequester = new MockLoginStoringRestRequester(); + SFSession sfSession = new SFSession($"account=test;user=test;password=test", null, restRequester); + + // act + sfSession.Open(); + + // assert + Assert.AreEqual(1, restRequester.LoginRequests.Count); + var loginRequest = restRequester.LoginRequests[0]; + Assert.IsNull(loginRequest.data.passcode); + Assert.AreEqual("push", loginRequest.data.extAuthnDuoMethod); + } + + [Test] + public void TestPushMFAWithAuthenticationCacheMFAToken() + { + // arrange + var restRequester = new MockLoginMFATokenCacheRestRequester(); + var sfSession = new SFSession($"account=test;user=test;password=test;authenticator=username_password_mfa", null, restRequester); + + // act + sfSession.Open(); + + // assert + Assert.AreEqual(1, restRequester.LoginRequests.Count); + var loginRequest = restRequester.LoginRequests.Dequeue(); + Assert.IsNull(loginRequest.data.passcode); + Assert.IsTrue(loginRequest.data.SessionParameters.TryGetValue(SFSessionParameter.CLIENT_REQUEST_MFA_TOKEN, out var value) && (bool)value); + Assert.AreEqual("push", loginRequest.data.extAuthnDuoMethod); + } + + [Test] + public void TestMFATokenCacheReturnedToSession() + { + // arrange + var testToken = "testToken1234"; + var restRequester = new MockLoginMFATokenCacheRestRequester(); + var sfSession = new SFSession($"account=test;user=test;password=test;authenticator=username_password_mfa", null, restRequester); + restRequester.LoginResponses.Enqueue(new LoginResponseData() + { + mfaToken = testToken, + authResponseSessionInfo = new SessionInfo() + }); + // act + sfSession.Open(); + + // assert + Assert.AreEqual(1, restRequester.LoginRequests.Count); + var loginRequest = restRequester.LoginRequests.Dequeue(); + Assert.AreEqual(SecureStringHelper.Decode(sfSession._mfaToken), testToken); + Assert.IsNull(loginRequest.data.passcode); + Assert.IsTrue(loginRequest.data.SessionParameters.TryGetValue(SFSessionParameter.CLIENT_REQUEST_MFA_TOKEN, out var value) && (bool)value); + Assert.AreEqual("push", loginRequest.data.extAuthnDuoMethod); + } + + [Test] + public void TestMFATokenCacheUsedInNewConnection() + { + // arrange + var testToken = "testToken1234"; + var restRequester = new MockLoginMFATokenCacheRestRequester(); + var connectionString = $"account=test;user=test;password=test;authenticator=username_password_mfa"; + var sfSession = new SFSession(connectionString, null, restRequester); + restRequester.LoginResponses.Enqueue(new LoginResponseData() + { + mfaToken = testToken, + authResponseSessionInfo = new SessionInfo() + }); + sfSession.Open(); + var sfSessionWithCachedToken = new SFSession(connectionString, null, restRequester); + // act + sfSessionWithCachedToken.Open(); + + // assert + Assert.AreEqual(2, restRequester.LoginRequests.Count); + var firstLoginRequest = restRequester.LoginRequests.Dequeue(); + Assert.AreEqual(SecureStringHelper.Decode(sfSession._mfaToken), testToken); + Assert.IsNull(firstLoginRequest.data.passcode); + Assert.IsTrue(firstLoginRequest.data.SessionParameters.TryGetValue(SFSessionParameter.CLIENT_REQUEST_MFA_TOKEN, out var value) && (bool)value); + Assert.AreEqual("push", firstLoginRequest.data.extAuthnDuoMethod); + + var secondLoginRequest = restRequester.LoginRequests.Dequeue(); + Assert.AreEqual(secondLoginRequest.data.Token, testToken); + } } } diff --git a/Snowflake.Data.Tests/UnitTests/SecretDetectorTest.cs b/Snowflake.Data.Tests/UnitTests/SecretDetectorTest.cs index 82c59a63c..a25b263f9 100644 --- a/Snowflake.Data.Tests/UnitTests/SecretDetectorTest.cs +++ b/Snowflake.Data.Tests/UnitTests/SecretDetectorTest.cs @@ -273,6 +273,14 @@ public void TestPasswordProperty() BasicMasking(@"somethingBefore=cccc;private_key_pwd=", @"somethingBefore=cccc;private_key_pwd=****"); BasicMasking(@"somethingBefore=cccc;private_key_pwd =aa;somethingNext=bbbb", @"somethingBefore=cccc;private_key_pwd =****"); BasicMasking(@"somethingBefore=cccc;private_key_pwd="" 'aa", @"somethingBefore=cccc;private_key_pwd=****"); + + BasicMasking(@"somethingBefore=cccc;passcode=aa", @"somethingBefore=cccc;passcode=****"); + BasicMasking(@"somethingBefore=cccc;passcode=aa;somethingNext=bbbb", @"somethingBefore=cccc;passcode=****"); + BasicMasking(@"somethingBefore=cccc;passcode=""aa"";somethingNext=bbbb", @"somethingBefore=cccc;passcode=****"); + BasicMasking(@"somethingBefore=cccc;passcode=;somethingNext=bbbb", @"somethingBefore=cccc;passcode=****"); + BasicMasking(@"somethingBefore=cccc;passcode=", @"somethingBefore=cccc;passcode=****"); + BasicMasking(@"somethingBefore=cccc;passcode =aa;somethingNext=bbbb", @"somethingBefore=cccc;passcode =****"); + BasicMasking(@"somethingBefore=cccc;passcode="" 'aa", @"somethingBefore=cccc;passcode=****"); } [Test] diff --git a/Snowflake.Data.Tests/UnitTests/Session/SessionOrCreationTokensTest.cs b/Snowflake.Data.Tests/UnitTests/Session/SessionOrCreationTokensTest.cs index 7d2b1a603..da5863475 100644 --- a/Snowflake.Data.Tests/UnitTests/Session/SessionOrCreationTokensTest.cs +++ b/Snowflake.Data.Tests/UnitTests/Session/SessionOrCreationTokensTest.cs @@ -10,13 +10,13 @@ namespace Snowflake.Data.Tests.UnitTests.Session public class SessionOrCreationTokensTest { private SFSession _session = new SFSession("account=test;user=test;password=test", null); - + [Test] public void TestNoBackgroundSessionsToCreateWhenInitialisedWithSession() { // arrange var sessionOrTokens = new SessionOrCreationTokens(_session); - + // act var backgroundCreationTokens = sessionOrTokens.BackgroundSessionCreationTokens(); @@ -32,14 +32,14 @@ public void TestReturnFirstCreationToken() .Select(_ => sessionCreationTokenCounter.NewToken()) .ToList(); var sessionOrTokens = new SessionOrCreationTokens(tokens); - + // act var token = sessionOrTokens.SessionCreationToken(); - + // assert Assert.AreSame(tokens[0], token); } - + [Test] public void TestReturnCreationTokensFromTheSecondOneForBackgroundExecution() { @@ -49,10 +49,10 @@ public void TestReturnCreationTokensFromTheSecondOneForBackgroundExecution() .Select(_ => sessionCreationTokenCounter.NewToken()) .ToList(); var sessionOrTokens = new SessionOrCreationTokens(tokens); - + // act var backgroundTokens = sessionOrTokens.BackgroundSessionCreationTokens(); - + // assert Assert.AreEqual(2, backgroundTokens.Count); Assert.AreSame(tokens[1], backgroundTokens[0]); diff --git a/Snowflake.Data.Tests/UnitTests/Session/SessionPoolTest.cs b/Snowflake.Data.Tests/UnitTests/Session/SessionPoolTest.cs index fca8f7de1..14115824e 100644 --- a/Snowflake.Data.Tests/UnitTests/Session/SessionPoolTest.cs +++ b/Snowflake.Data.Tests/UnitTests/Session/SessionPoolTest.cs @@ -71,17 +71,17 @@ public void TestOverrideSetPooling() [Test] [TestCase("account=someAccount;db=someDb;host=someHost;user=SomeUser;port=443", "somePassword", " [pool: account=someAccount;db=someDb;host=someHost;user=SomeUser;port=443;]")] - [TestCase("account=someAccount;db=someDb;host=someHost;password=somePassword;user=SomeUser;port=443", null, " [pool: account=someAccount;db=someDb;host=someHost;user=SomeUser;port=443;]")] - [TestCase("account=someAccount;db=someDb;host=someHost;password=somePassword;user=SomeUser;private_key=SomePrivateKey;port=443", null, " [pool: account=someAccount;db=someDb;host=someHost;user=SomeUser;port=443;]")] - [TestCase("account=someAccount;db=someDb;host=someHost;password=somePassword;user=SomeUser;token=someToken;port=443", null, " [pool: account=someAccount;db=someDb;host=someHost;user=SomeUser;port=443;]")] - [TestCase("account=someAccount;db=someDb;host=someHost;password=somePassword;user=SomeUser;private_key_pwd=somePrivateKeyPwd;port=443", null, " [pool: account=someAccount;db=someDb;host=someHost;user=SomeUser;port=443;]")] - [TestCase("account=someAccount;db=someDb;host=someHost;password=somePassword;user=SomeUser;proxyPassword=someProxyPassword;port=443", null, " [pool: account=someAccount;db=someDb;host=someHost;user=SomeUser;port=443;]")] - [TestCase("ACCOUNT=someAccount;DB=someDb;HOST=someHost;PASSWORD=somePassword;USER=SomeUser;PORT=443", null, " [pool: account=someAccount;db=someDb;host=someHost;user=SomeUser;port=443;]")] - [TestCase("ACCOUNT=\"someAccount\";DB=\"someDb\";HOST=\"someHost\";PASSWORD=\"somePassword\";USER=\"SomeUser\";PORT=\"443\"", null, " [pool: account=someAccount;db=someDb;host=someHost;user=SomeUser;port=443;]")] + [TestCase("account=someAccount;db=someDb;host=someHost;password=somePassword;passcode=123;user=SomeUser;port=443", null, " [pool: account=someAccount;db=someDb;host=someHost;user=SomeUser;port=443;]")] + [TestCase("account=someAccount;db=someDb;host=someHost;password=somePassword;passcode=123;user=SomeUser;private_key=SomePrivateKey;port=443", null, " [pool: account=someAccount;db=someDb;host=someHost;user=SomeUser;port=443;]")] + [TestCase("account=someAccount;db=someDb;host=someHost;password=somePassword;passcode=123;user=SomeUser;token=someToken;port=443", null, " [pool: account=someAccount;db=someDb;host=someHost;user=SomeUser;port=443;]")] + [TestCase("account=someAccount;db=someDb;host=someHost;password=somePassword;passcode=123;user=SomeUser;private_key_pwd=somePrivateKeyPwd;port=443", null, " [pool: account=someAccount;db=someDb;host=someHost;user=SomeUser;port=443;]")] + [TestCase("account=someAccount;db=someDb;host=someHost;password=somePassword;passcode=123;user=SomeUser;proxyPassword=someProxyPassword;port=443", null, " [pool: account=someAccount;db=someDb;host=someHost;user=SomeUser;port=443;]")] + [TestCase("ACCOUNT=someAccount;DB=someDb;HOST=someHost;PASSWORD=somePassword;passcode=123;USER=SomeUser;PORT=443", null, " [pool: account=someAccount;db=someDb;host=someHost;user=SomeUser;port=443;]")] + [TestCase("ACCOUNT=\"someAccount\";DB=\"someDb\";HOST=\"someHost\";PASSWORD=\"somePassword\";PASSCODE=\"123\";USER=\"SomeUser\";PORT=\"443\"", null, " [pool: account=someAccount;db=someDb;host=someHost;user=SomeUser;port=443;]")] public void TestPoolIdentificationBasedOnConnectionString(string connectionString, string password, string expectedPoolIdentification) { // arrange - var securePassword = password == null ? null : new NetworkCredential("", password).SecurePassword; + var securePassword = password == null ? null : SecureStringHelper.Encode(password); var pool = SessionPool.CreateSessionPool(connectionString, securePassword); // act diff --git a/Snowflake.Data.Tests/UnitTests/Tools/DirectoryInformationTest.cs b/Snowflake.Data.Tests/UnitTests/Tools/DirectoryInformationTest.cs new file mode 100644 index 000000000..8167e484a --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/Tools/DirectoryInformationTest.cs @@ -0,0 +1,53 @@ +using System; +using System.Collections.Generic; +using NUnit.Framework; +using Snowflake.Data.Core.Tools; + +namespace Snowflake.Data.Tests.UnitTests.Tools +{ + [TestFixture] + public class DirectoryInformationTest + { + [Test] + [TestCaseSource(nameof(OldCreatingDatesTestCases))] + public void TestIsCreatedEarlierThanSeconds(DateTime? createdDate, DateTime utcNow) + { + // arrange + var directoryInformation = new DirectoryInformation(true, createdDate); + + // act + var result = directoryInformation.IsCreatedEarlierThanSeconds(60, utcNow); + + // assert + Assert.AreEqual(true, result); + } + + [Test] + [TestCaseSource(nameof(NewCreatingDatesTestCases))] + public void TestIsNotCreatedEarlierThanSeconds(bool dirExists, DateTime? createdDate, DateTime utcNow) + { + // arrange + var directoryInformation = new DirectoryInformation(dirExists, createdDate); + + // act + var result = directoryInformation.IsCreatedEarlierThanSeconds(60, utcNow); + + // assert + Assert.AreEqual(false, result); + } + + internal static IEnumerable OldCreatingDatesTestCases() + { + yield return new object[] { DateTime.UtcNow.AddMinutes(-2), DateTime.UtcNow }; + yield return new object[] { DateTime.UtcNow.AddSeconds(-61), DateTime.UtcNow }; + } + + internal static IEnumerable NewCreatingDatesTestCases() + { + yield return new object[] { true, DateTime.UtcNow.AddSeconds(-30), DateTime.UtcNow }; + yield return new object[] { true, DateTime.UtcNow.AddSeconds(30), DateTime.UtcNow }; + yield return new object[] { true, DateTime.UtcNow, DateTime.UtcNow }; + yield return new object[] { false, null, DateTime.UtcNow }; + } + } +} diff --git a/Snowflake.Data.Tests/UnitTests/Tools/DirectoryUnixInformationTest.cs b/Snowflake.Data.Tests/UnitTests/Tools/DirectoryUnixInformationTest.cs new file mode 100644 index 000000000..8610a58a3 --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/Tools/DirectoryUnixInformationTest.cs @@ -0,0 +1,95 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +using System.IO; +using Mono.Unix; +using NUnit.Framework; +using Snowflake.Data.Core.Tools; + +namespace Snowflake.Data.Tests.UnitTests.Tools +{ + [TestFixture] + public class DirectoryUnixInformationTest + { + private const long UserId = 5; + private const long AnotherUserId = 6; + static readonly string s_directoryFullName = Path.Combine(Path.GetTempPath(), Path.GetRandomFileName()); + + [Test] + [TestCase(FileAccessPermissions.UserWrite)] + [TestCase(FileAccessPermissions.UserRead)] + [TestCase(FileAccessPermissions.UserExecute)] + [TestCase(FileAccessPermissions.UserReadWriteExecute)] + public void TestSafeDirectory(FileAccessPermissions securePermissions) + { + // arrange + var dirInfo = new DirectoryUnixInformation(s_directoryFullName, true, securePermissions, UserId); + + // act + var isSafe = dirInfo.IsSafe(UserId); + + // assert + Assert.True(isSafe); + } + + [Test] + [TestCase(FileAccessPermissions.UserReadWriteExecute | FileAccessPermissions.GroupRead)] + [TestCase(FileAccessPermissions.UserReadWriteExecute | FileAccessPermissions.OtherRead)] + public void TestUnsafePermissions(FileAccessPermissions unsecurePermissions) + { + // arrange + var dirInfo = new DirectoryUnixInformation(s_directoryFullName, true, unsecurePermissions, UserId); + + // act + var isSafe = dirInfo.IsSafe(UserId); + + // assert + Assert.False(isSafe); + } + + [Test] + public void TestSafeExactlyDirectory() + { + // arrange + var dirInfo = new DirectoryUnixInformation(s_directoryFullName, true, FileAccessPermissions.UserReadWriteExecute, UserId); + + // act + var isSafe = dirInfo.IsSafeExactly(UserId); + + // assert + Assert.True(isSafe); + } + + [Test] + [TestCase(FileAccessPermissions.UserRead)] + [TestCase(FileAccessPermissions.UserReadWriteExecute | FileAccessPermissions.GroupRead)] + [TestCase(FileAccessPermissions.UserReadWriteExecute | FileAccessPermissions.OtherRead)] + public void TestUnsafeExactlyPermissions(FileAccessPermissions unsecurePermissions) + { + // arrange + var dirInfo = new DirectoryUnixInformation(s_directoryFullName, true, unsecurePermissions, UserId); + + // act + var isSafe = dirInfo.IsSafeExactly(UserId); + + // assert + Assert.False(isSafe); + } + + [Test] + public void TestOwnedByOthers() + { + // arrange + var dirInfo = new DirectoryUnixInformation(s_directoryFullName, true, FileAccessPermissions.UserReadWriteExecute, UserId); + + // act + var isSafe = dirInfo.IsSafe(AnotherUserId); + var isSafeExactly = dirInfo.IsSafeExactly(AnotherUserId); + + // assert + Assert.False(isSafe); + Assert.False(isSafeExactly); + } + } +} diff --git a/Snowflake.Data.Tests/UnitTests/Tools/UnixOperationsTest.cs b/Snowflake.Data.Tests/UnitTests/Tools/UnixOperationsTest.cs index 14e2df121..9c471965f 100644 --- a/Snowflake.Data.Tests/UnitTests/Tools/UnixOperationsTest.cs +++ b/Snowflake.Data.Tests/UnitTests/Tools/UnixOperationsTest.cs @@ -6,6 +6,7 @@ using Mono.Unix.Native; using NUnit.Framework; using Snowflake.Data.Core; +using Snowflake.Data.Core.CredentialManager.Infrastructure; using Snowflake.Data.Core.Tools; using static Snowflake.Data.Tests.UnitTests.Configuration.EasyLoggingConfigGenerator; @@ -15,7 +16,7 @@ namespace Snowflake.Data.Tests.Tools public class UnixOperationsTest { private static UnixOperations s_unixOperations; - private static readonly string s_workingDirectory = Path.Combine(Path.GetTempPath(), "easy_logging_test_configs_", Path.GetRandomFileName()); + private static readonly string s_workingDirectory = Path.Combine(Path.GetTempPath(), Path.GetRandomFileName()); [OneTimeSetUp] public static void BeforeAll() @@ -38,16 +39,12 @@ public static void AfterAll() } [Test] + [Platform(Exclude = "Win")] public void TestDetectGroupOrOthersWritablePermissions( [ValueSource(nameof(GroupOrOthersWritablePermissions))] FilePermissions groupOrOthersWritablePermissions, [ValueSource(nameof(GroupNotWritablePermissions))] FilePermissions groupNotWritablePermissions, [ValueSource(nameof(OtherNotWritablePermissions))] FilePermissions otherNotWritablePermissions) { - if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) - { - Assert.Ignore("skip test on Windows"); - } - // arrange var filePath = CreateConfigTempFile(s_workingDirectory, "random text"); var readWriteUserPermissions = FilePermissions.S_IRUSR | FilePermissions.S_IWUSR; @@ -62,16 +59,12 @@ public void TestDetectGroupOrOthersWritablePermissions( } [Test] + [Platform(Exclude = "Win")] public void TestDetectGroupOrOthersNotWritablePermissions( [ValueSource(nameof(UserPermissions))] FilePermissions userPermissions, [ValueSource(nameof(GroupNotWritablePermissions))] FilePermissions groupNotWritablePermissions, [ValueSource(nameof(OtherNotWritablePermissions))] FilePermissions otherNotWritablePermissions) { - if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) - { - Assert.Ignore("skip test on Windows"); - } - var filePath = CreateConfigTempFile(s_workingDirectory, "random text"); var filePermissions = userPermissions | groupNotWritablePermissions | otherNotWritablePermissions; Syscall.chmod(filePath, filePermissions); @@ -84,13 +77,10 @@ public void TestDetectGroupOrOthersNotWritablePermissions( } [Test] + [Platform(Exclude = "Win")] public void TestReadAllTextCheckingPermissionsUsingTomlConfigurationFileValidations( [ValueSource(nameof(UserAllowedPermissions))] FilePermissions userAllowedPermissions) { - if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) - { - Assert.Ignore("skip test on Windows"); - } var content = "random text"; var filePath = CreateConfigTempFile(s_workingDirectory, content); Syscall.chmod(filePath, userAllowedPermissions); @@ -103,7 +93,21 @@ public void TestReadAllTextCheckingPermissionsUsingTomlConfigurationFileValidati } [Test] - public void TestFailIfGroupOrOthersHavePermissionsToFileWithTomlConfigurationValidations([ValueSource(nameof(UserReadWritePermissions))] FilePermissions userPermissions, + [Platform(Exclude = "Win")] + public void TestWriteAllTextCheckingPermissionsUsingSFCredentialManagerFileValidations( + [ValueSource(nameof(UserAllowedWritePermissions))] FilePermissions userAllowedPermissions) + { + var content = "random text"; + var filePath = CreateConfigTempFile(s_workingDirectory, content); + Syscall.chmod(filePath, userAllowedPermissions); + + // act and assert + Assert.DoesNotThrow(() => s_unixOperations.WriteAllText(filePath,"test", SFCredentialManagerFileImpl.Instance.ValidateFilePermissions)); + } + + [Test] + [Platform(Exclude = "Win")] + public void TestFailIfGroupOrOthersHavePermissionsToFileWhileReadingWithUnixValidationsUsingTomlConfig([ValueSource(nameof(UserReadWritePermissions))] FilePermissions userPermissions, [ValueSource(nameof(GroupPermissions))] FilePermissions groupPermissions, [ValueSource(nameof(OthersPermissions))] FilePermissions othersPermissions) { @@ -112,10 +116,27 @@ public void TestFailIfGroupOrOthersHavePermissionsToFileWithTomlConfigurationVal Assert.Ignore("Skip test when group and others have no permissions"); } - if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + var content = "random text"; + var filePath = CreateConfigTempFile(s_workingDirectory, content); + + var filePermissions = userPermissions | groupPermissions | othersPermissions; + Syscall.chmod(filePath, filePermissions); + + // act and assert + Assert.Throws(() => s_unixOperations.ReadAllText(filePath, TomlConnectionBuilder.ValidateFilePermissions), "Attempting to read a file with too broad permissions assigned"); + } + + [Test] + [Platform(Exclude = "Win")] + public void TestFailIfGroupOrOthersHavePermissionsToFileWhileWritingWithUnixValidationsForCredentialManagerFile([ValueSource(nameof(UserReadWritePermissions))] FilePermissions userPermissions, + [ValueSource(nameof(GroupPermissions))] FilePermissions groupPermissions, + [ValueSource(nameof(OthersPermissions))] FilePermissions othersPermissions) + { + if(groupPermissions == 0 && othersPermissions == 0) { - Assert.Ignore("skip test on Windows"); + Assert.Ignore("Skip test when group and others have no permissions"); } + var content = "random text"; var filePath = CreateConfigTempFile(s_workingDirectory, content); @@ -123,7 +144,7 @@ public void TestFailIfGroupOrOthersHavePermissionsToFileWithTomlConfigurationVal Syscall.chmod(filePath, filePermissions); // act and assert - Assert.Throws(() => s_unixOperations.ReadAllText(filePath, TomlConnectionBuilder.ValidateFilePermissions), "Attempting to read a file with too broad permissions assigned"); + Assert.Throws(() => s_unixOperations.WriteAllText(filePath, "test", SFCredentialManagerFileImpl.Instance.ValidateFilePermissions), "Attempting to read or write a file with too broad permissions assigned"); } public static IEnumerable UserPermissions() @@ -186,6 +207,11 @@ public static IEnumerable UserAllowedPermissions() yield return FilePermissions.S_IRUSR | FilePermissions.S_IWUSR; } + public static IEnumerable UserAllowedWritePermissions() + { + yield return FilePermissions.S_IRUSR | FilePermissions.S_IWUSR; + } + public static IEnumerable GroupOrOthersReadablePermissions() { yield return 0; diff --git a/Snowflake.Data/Client/ISnowflakeCredentialManager.cs b/Snowflake.Data/Client/ISnowflakeCredentialManager.cs new file mode 100644 index 000000000..802d8fe21 --- /dev/null +++ b/Snowflake.Data/Client/ISnowflakeCredentialManager.cs @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +namespace Snowflake.Data.Client +{ + public interface ISnowflakeCredentialManager + { + string GetCredentials(string key); + + void RemoveCredentials(string key); + + void SaveCredentials(string key, string token); + } +} diff --git a/Snowflake.Data/Client/SnowflakeCredentialManagerFactory.cs b/Snowflake.Data/Client/SnowflakeCredentialManagerFactory.cs new file mode 100644 index 000000000..124dd4d45 --- /dev/null +++ b/Snowflake.Data/Client/SnowflakeCredentialManagerFactory.cs @@ -0,0 +1,95 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +using System; +using System.Runtime.InteropServices; +using Snowflake.Data.Core; +using Snowflake.Data.Core.CredentialManager; +using Snowflake.Data.Core.CredentialManager.Infrastructure; +using Snowflake.Data.Core.Tools; +using Snowflake.Data.Log; + +namespace Snowflake.Data.Client +{ + public class SnowflakeCredentialManagerFactory + { + private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); + + private static readonly object s_credentialManagerLock = new object(); + private static readonly ISnowflakeCredentialManager s_defaultCredentialManager = GetDefaultCredentialManager(); + + private static ISnowflakeCredentialManager s_credentialManager = s_defaultCredentialManager; + + internal static string GetSecureCredentialKey(string host, string user, TokenType tokenType) + { + return $"{host.ToUpper()}:{user.ToUpper()}:{tokenType.ToString().ToUpper()}".ToSha256Hash(); + } + + + public static void UseDefaultCredentialManager() + { + SetCredentialManager(GetDefaultCredentialManager()); + } + + public static void UseInMemoryCredentialManager() + { + SetCredentialManager(SFCredentialManagerInMemoryImpl.Instance); + } + + public static void UseFileCredentialManager() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + throw new Exception("File credential manager implementation is not supported on Windows"); + } + SetCredentialManager(SFCredentialManagerFileImpl.Instance); + } + + public static void UseWindowsCredentialManager() + { + if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + throw new Exception("Windows native credential manager implementation can be used only on Windows"); + } + SetCredentialManager(SFCredentialManagerWindowsNativeImpl.Instance); + } + + public static void SetCredentialManager(ISnowflakeCredentialManager customCredentialManager) + { + lock (s_credentialManagerLock) + { + if (customCredentialManager == null) + { + throw new SnowflakeDbException(SFError.INTERNAL_ERROR, + "Credential manager cannot be null. If you want to use the default credential manager, please call the UseDefaultCredentialManager method."); + } + + if (customCredentialManager == s_credentialManager) + { + s_logger.Info($"Credential manager is already set to: {customCredentialManager.GetType().Name}"); + return; + } + + s_logger.Info($"Setting the credential manager: {customCredentialManager.GetType().Name}"); + s_credentialManager = customCredentialManager; + } + } + + public static ISnowflakeCredentialManager GetCredentialManager() + { + var credentialManager = s_credentialManager; + var typeCredentialText = credentialManager == s_defaultCredentialManager ? "default" : "custom"; + s_logger.Info($"Using {typeCredentialText} credential manager: {credentialManager?.GetType().Name}"); + return credentialManager; + } + + private static ISnowflakeCredentialManager GetDefaultCredentialManager() + { + return RuntimeInformation.IsOSPlatform(OSPlatform.Windows) + ? (ISnowflakeCredentialManager) + SFCredentialManagerWindowsNativeImpl.Instance + : SFCredentialManagerFileImpl.Instance; + } + } +} diff --git a/Snowflake.Data/Client/SnowflakeDbConnection.cs b/Snowflake.Data/Client/SnowflakeDbConnection.cs index 70fa642ea..716861713 100755 --- a/Snowflake.Data/Client/SnowflakeDbConnection.cs +++ b/Snowflake.Data/Client/SnowflakeDbConnection.cs @@ -76,6 +76,8 @@ public SecureString Password get; set; } + public SecureString Passcode { get; set; } + public bool IsOpen() { return _connectionState == ConnectionState.Open && SfSession != null; @@ -277,7 +279,7 @@ public override void Open() { FillConnectionStringFromTomlConfigIfNotSet(); OnSessionConnecting(); - SfSession = SnowflakeDbConnectionPool.GetSession(ConnectionString, Password); + SfSession = SnowflakeDbConnectionPool.GetSession(ConnectionString, Password, Passcode); if (SfSession == null) throw new SnowflakeDbException(SFError.INTERNAL_ERROR, "Could not open session"); logger.Debug($"Connection open with pooled session: {SfSession.sessionId}"); @@ -320,7 +322,7 @@ public override Task OpenAsync(CancellationToken cancellationToken) OnSessionConnecting(); FillConnectionStringFromTomlConfigIfNotSet(); return SnowflakeDbConnectionPool - .GetSessionAsync(ConnectionString, Password, cancellationToken) + .GetSessionAsync(ConnectionString, Password, Passcode, cancellationToken) .ContinueWith(previousTask => { if (previousTask.IsFaulted) diff --git a/Snowflake.Data/Client/SnowflakeDbConnectionPool.cs b/Snowflake.Data/Client/SnowflakeDbConnectionPool.cs index fcee66e1a..fd10eadd8 100644 --- a/Snowflake.Data/Client/SnowflakeDbConnectionPool.cs +++ b/Snowflake.Data/Client/SnowflakeDbConnectionPool.cs @@ -31,16 +31,16 @@ private static IConnectionManager ConnectionManager } } - internal static SFSession GetSession(string connectionString, SecureString password) + internal static SFSession GetSession(string connectionString, SecureString password, SecureString passcode) { s_logger.Debug($"SnowflakeDbConnectionPool::GetSession"); - return ConnectionManager.GetSession(connectionString, password); + return ConnectionManager.GetSession(connectionString, password, passcode); } - internal static Task GetSessionAsync(string connectionString, SecureString password, CancellationToken cancellationToken) + internal static Task GetSessionAsync(string connectionString, SecureString password, SecureString passcode, CancellationToken cancellationToken) { s_logger.Debug($"SnowflakeDbConnectionPool::GetSessionAsync"); - return ConnectionManager.GetSessionAsync(connectionString, password, cancellationToken); + return ConnectionManager.GetSessionAsync(connectionString, password, passcode, cancellationToken); } public static SnowflakeDbSessionPool GetPool(string connectionString, SecureString password) diff --git a/Snowflake.Data/Core/Authenticator/BasicAuthenticator.cs b/Snowflake.Data/Core/Authenticator/BasicAuthenticator.cs index a26d542d3..2dba66594 100644 --- a/Snowflake.Data/Core/Authenticator/BasicAuthenticator.cs +++ b/Snowflake.Data/Core/Authenticator/BasicAuthenticator.cs @@ -34,6 +34,7 @@ protected override void SetSpecializedAuthenticatorData(ref LoginRequestData dat { // Only need to add the password to Data for basic authentication data.password = session.properties[SFSessionProperty.PASSWORD]; + SetSecondaryAuthenticationData(ref data); } } diff --git a/Snowflake.Data/Core/Authenticator/ExternalBrowserAuthenticator.cs b/Snowflake.Data/Core/Authenticator/ExternalBrowserAuthenticator.cs index e39ec18f8..baba5f8a5 100644 --- a/Snowflake.Data/Core/Authenticator/ExternalBrowserAuthenticator.cs +++ b/Snowflake.Data/Core/Authenticator/ExternalBrowserAuthenticator.cs @@ -260,6 +260,7 @@ protected override void SetSpecializedAuthenticatorData(ref LoginRequestData dat // Add the token and proof key to the Data data.Token = _samlResponseToken; data.ProofKey = _proofKey; + SetSpecializedAuthenticatorData(ref data); } private string GetLoginUrl(string proofKey, int localPort) diff --git a/Snowflake.Data/Core/Authenticator/IAuthenticator.cs b/Snowflake.Data/Core/Authenticator/IAuthenticator.cs index 7a41a8335..267f878aa 100644 --- a/Snowflake.Data/Core/Authenticator/IAuthenticator.cs +++ b/Snowflake.Data/Core/Authenticator/IAuthenticator.cs @@ -101,6 +101,24 @@ protected void Login() /// The login request data to update. protected abstract void SetSpecializedAuthenticatorData(ref LoginRequestData data); + protected void SetSecondaryAuthenticationData(ref LoginRequestData data) + { + if (session.properties.TryGetValue(SFSessionProperty.PASSCODEINPASSWORD, out var passcodeInPasswordString) + && bool.TryParse(passcodeInPasswordString, out var passcodeInPassword) + && passcodeInPassword) + { + data.extAuthnDuoMethod = "passcode"; + } else if (session.properties.TryGetValue(SFSessionProperty.PASSCODE, out var passcode) && !string.IsNullOrEmpty(passcode)) + { + data.extAuthnDuoMethod = "passcode"; + data.passcode = passcode; + } + else + { + data.extAuthnDuoMethod = "push"; + } + } + /// /// Builds a simple login request. Each authenticator will fill the Data part with their /// specialized information. The common Data attributes are already filled (clientAppId, @@ -122,10 +140,11 @@ private SFRestRequest BuildLoginRequest() SessionParameters = session.ParameterMap, Authenticator = authName, }; - SetSpecializedAuthenticatorData(ref data); - return session.BuildTimeoutRestRequest(loginUrl, new LoginRequest() { data = data }); + return data.HttpTimeout.HasValue ? + session.BuildTimeoutRestRequest(loginUrl, new LoginRequest() { data = data }, data.HttpTimeout.Value) : + session.BuildTimeoutRestRequest(loginUrl, new LoginRequest() { data = data }); } } @@ -187,6 +206,10 @@ internal static IAuthenticator GetAuthenticator(SFSession session) return new OAuthAuthenticator(session); } + else if (type.Equals(MFACacheAuthenticator.AuthName, StringComparison.InvariantCultureIgnoreCase)) + { + return new MFACacheAuthenticator(session); + } // Okta would provide a url of form: https://xxxxxx.okta.com or https://xxxxxx.oktapreview.com or https://vanity.url/snowflake/okta else if (type.Contains("okta") && type.StartsWith("https://")) { diff --git a/Snowflake.Data/Core/Authenticator/KeyPairAuthenticator.cs b/Snowflake.Data/Core/Authenticator/KeyPairAuthenticator.cs index 7d86d02c9..44b9b8bec 100644 --- a/Snowflake.Data/Core/Authenticator/KeyPairAuthenticator.cs +++ b/Snowflake.Data/Core/Authenticator/KeyPairAuthenticator.cs @@ -75,6 +75,7 @@ protected override void SetSpecializedAuthenticatorData(ref LoginRequestData dat { // Add the token to the Data attribute data.Token = jwtToken; + SetSpecializedAuthenticatorData(ref data); } /// diff --git a/Snowflake.Data/Core/Authenticator/MFACacheAuthenticator.cs b/Snowflake.Data/Core/Authenticator/MFACacheAuthenticator.cs new file mode 100644 index 000000000..1e65ca376 --- /dev/null +++ b/Snowflake.Data/Core/Authenticator/MFACacheAuthenticator.cs @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +using System; +using System.Threading; +using System.Threading.Tasks; +using Snowflake.Data.Core.Tools; + +namespace Snowflake.Data.Core.Authenticator +{ + class MFACacheAuthenticator : BaseAuthenticator, IAuthenticator + { + public const string AuthName = "username_password_mfa"; + private const int MfaLoginHttpTimeout = 60; + + internal MFACacheAuthenticator(SFSession session) : base(session, AuthName) + { + } + + /// + async Task IAuthenticator.AuthenticateAsync(CancellationToken cancellationToken) + { + await base.LoginAsync(cancellationToken); + } + + /// + void IAuthenticator.Authenticate() + { + base.Login(); + } + + /// + protected override void SetSpecializedAuthenticatorData(ref LoginRequestData data) + { + // Only need to add the password to Data for basic authentication + data.password = session.properties[SFSessionProperty.PASSWORD]; + data.SessionParameters[SFSessionParameter.CLIENT_REQUEST_MFA_TOKEN] = true; + data.HttpTimeout = TimeSpan.FromSeconds(MfaLoginHttpTimeout); + if (!string.IsNullOrEmpty(session._mfaToken?.ToString())) + { + data.Token = SecureStringHelper.Decode(session._mfaToken); + } + SetSecondaryAuthenticationData(ref data); + } + } + +} diff --git a/Snowflake.Data/Core/Authenticator/OAuthAuthenticator.cs b/Snowflake.Data/Core/Authenticator/OAuthAuthenticator.cs index f36d0353e..85599266e 100644 --- a/Snowflake.Data/Core/Authenticator/OAuthAuthenticator.cs +++ b/Snowflake.Data/Core/Authenticator/OAuthAuthenticator.cs @@ -1,7 +1,4 @@ using Snowflake.Data.Log; -using System; -using System.Collections.Generic; -using System.Text; using System.Threading; using System.Threading.Tasks; @@ -48,6 +45,7 @@ protected override void SetSpecializedAuthenticatorData(ref LoginRequestData dat data.Token = session.properties[SFSessionProperty.TOKEN]; // Remove the login name for an OAuth session data.loginName = ""; + SetSecondaryAuthenticationData(ref data); } } } diff --git a/Snowflake.Data/Core/Authenticator/OktaAuthenticator.cs b/Snowflake.Data/Core/Authenticator/OktaAuthenticator.cs index 7c364d3c5..164949864 100644 --- a/Snowflake.Data/Core/Authenticator/OktaAuthenticator.cs +++ b/Snowflake.Data/Core/Authenticator/OktaAuthenticator.cs @@ -248,6 +248,7 @@ private SamlRestRequest BuildSamlRestRequest(Uri ssoUrl, string onetimeToken) protected override void SetSpecializedAuthenticatorData(ref LoginRequestData data) { data.RawSamlResponse = _rawSamlTokenHtmlString; + SetSecondaryAuthenticationData(ref data); } private void VerifyUrls(Uri tokenOrSsoUrl, Uri sessionUrl) diff --git a/Snowflake.Data/Core/CredentialManager/Infrastructure/CredentialsFileContent.cs b/Snowflake.Data/Core/CredentialManager/Infrastructure/CredentialsFileContent.cs new file mode 100644 index 000000000..3b03ba686 --- /dev/null +++ b/Snowflake.Data/Core/CredentialManager/Infrastructure/CredentialsFileContent.cs @@ -0,0 +1,11 @@ +using Newtonsoft.Json; +using KeyTokenDict = System.Collections.Generic.Dictionary; + +namespace Snowflake.Data.Core.CredentialManager.Infrastructure +{ + internal class CredentialsFileContent + { + [JsonProperty(PropertyName = "tokens")] + internal KeyTokenDict Tokens { get; set; } + } +} diff --git a/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerFileImpl.cs b/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerFileImpl.cs new file mode 100644 index 000000000..89581483f --- /dev/null +++ b/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerFileImpl.cs @@ -0,0 +1,318 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +using Mono.Unix; +using Mono.Unix.Native; +using Newtonsoft.Json; +using Snowflake.Data.Client; +using Snowflake.Data.Core.Tools; +using Snowflake.Data.Log; +using System; +using System.IO; +using System.Linq; +using System.Security; +using System.Threading; +using KeyTokenDict = System.Collections.Generic.Dictionary; + +namespace Snowflake.Data.Core.CredentialManager.Infrastructure +{ + internal class SFCredentialManagerFileImpl : ISnowflakeCredentialManager + { + internal const int CredentialCacheLockDurationSeconds = 60; + + internal const FilePermissions CredentialCacheLockDirPermissions = FilePermissions.S_IRUSR; + + private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); + + private static readonly object s_lock = new object(); + + internal SFCredentialManagerFileStorage _fileStorage = null; + + private readonly FileOperations _fileOperations; + + private readonly DirectoryOperations _directoryOperations; + + private readonly UnixOperations _unixOperations; + + private readonly EnvironmentOperations _environmentOperations; + + public static readonly SFCredentialManagerFileImpl Instance = new SFCredentialManagerFileImpl(FileOperations.Instance, DirectoryOperations.Instance, UnixOperations.Instance, EnvironmentOperations.Instance); + + internal SFCredentialManagerFileImpl(FileOperations fileOperations, DirectoryOperations directoryOperations, UnixOperations unixOperations, EnvironmentOperations environmentOperations) + { + _fileOperations = fileOperations; + _directoryOperations = directoryOperations; + _unixOperations = unixOperations; + _environmentOperations = environmentOperations; + } + + public string GetCredentials(string key) + { + s_logger.Debug($"Getting credentials for key: {key}"); + lock (s_lock) + { + InitializeFileStorageIfNeeded(); + s_logger.Debug($"Getting credentials from json file in {_fileStorage.JsonCacheFilePath} for key: {key}"); + var lockAcquired = AcquireLockWithRetries(); // additional fs level locking is to synchronize file access across many applications + if (!lockAcquired) + { + s_logger.Error("Failed to acquire lock for reading credentials"); + return string.Empty; + } + try + { + if (_fileOperations.Exists(_fileStorage.JsonCacheFilePath)) + { + var keyTokenPairs = ReadJsonFile(); + if (keyTokenPairs.TryGetValue(key, out string token)) + { + return token; + } + } + } + catch (Exception exception) + { + s_logger.Error("Failed to get credentials", exception); + throw; + } + finally + { + ReleaseLock(); + } + } + s_logger.Info("Unable to get credentials for the specified key"); + return string.Empty; + } + + public void RemoveCredentials(string key) + { + s_logger.Debug($"Removing credentials for key: {key}"); + lock (s_lock) + { + InitializeFileStorageIfNeeded(); + s_logger.Debug($"Removing credentials from json file in {_fileStorage.JsonCacheFilePath} for key: {key}"); + var lockAcquired = AcquireLockWithRetries(); // additional fs level locking is to synchronize file access across many applications + if (!lockAcquired) + { + s_logger.Error("Failed to acquire lock for removing credentials"); + return; + } + try + { + if (_fileOperations.Exists(_fileStorage.JsonCacheFilePath)) + { + var keyTokenPairs = ReadJsonFile(); + keyTokenPairs.Remove(key); + WriteToJsonFile(keyTokenPairs); + } + } + catch (Exception exception) + { + s_logger.Error("Failed to remove credentials", exception); + throw; + } + finally + { + ReleaseLock(); + } + } + } + + public void SaveCredentials(string key, string token) + { + s_logger.Debug($"Saving credentials for key: {key}"); + lock (s_lock) + { + InitializeFileStorageIfNeeded(); + s_logger.Debug($"Saving credentials to json file in {_fileStorage.JsonCacheFilePath} for key: {key}"); + var lockAcquired = AcquireLockWithRetries(); // additional fs level locking is to synchronize file access across many applications + if (!lockAcquired) + { + s_logger.Error("Failed to acquire lock for saving credentials"); + return; + } + try + { + KeyTokenDict keyTokenPairs = _fileOperations.Exists(_fileStorage.JsonCacheFilePath) ? ReadJsonFile() : new KeyTokenDict(); + keyTokenPairs[key] = token; + WriteToJsonFile(keyTokenPairs); + } + catch (Exception exception) + { + s_logger.Error("Failed to save credentials", exception); + throw; + } + finally + { + ReleaseLock(); + } + } + } + + private void WriteToJsonFile(KeyTokenDict keyTokenPairs) + { + var credentials = new CredentialsFileContent { Tokens = keyTokenPairs }; + var jsonString = JsonConvert.SerializeObject(credentials); + WriteContentToJsonFile(jsonString); + } + + private void WriteContentToJsonFile(string content) + { + s_logger.Debug($"Writing credentials to json file in {_fileStorage.JsonCacheFilePath}"); + if (!_directoryOperations.Exists(_fileStorage.JsonCacheDirectory)) + { + _directoryOperations.CreateDirectory(_fileStorage.JsonCacheDirectory); + } + if (_fileOperations.Exists(_fileStorage.JsonCacheFilePath)) + { + s_logger.Info($"The existing json file for credential cache in {_fileStorage.JsonCacheFilePath} will be overwritten"); + } + else + { + s_logger.Info($"Creating the json file for credential cache in {_fileStorage.JsonCacheFilePath}"); + var createFileResult = _unixOperations.CreateFileWithPermissions(_fileStorage.JsonCacheFilePath, + FilePermissions.S_IRUSR | FilePermissions.S_IWUSR); + if (createFileResult == -1) + { + var errorMessage = "Failed to create the JSON token cache file"; + s_logger.Error(errorMessage); + throw new Exception(errorMessage); + } + } + _fileOperations.Write(_fileStorage.JsonCacheFilePath, content, ValidateFilePermissions); + } + + private KeyTokenDict ReadJsonFile() + { + string contentFile; + try + { + contentFile = _fileOperations.ReadAllText(_fileStorage.JsonCacheFilePath, ValidateFilePermissions); + } + catch (FileNotFoundException) + { + s_logger.Error("Failed to read the file with cached credentials because it does not exist"); + return new KeyTokenDict(); + } + try + { + var fileContent = JsonConvert.DeserializeObject(contentFile); + return (fileContent == null || fileContent.Tokens == null) ? new KeyTokenDict() : fileContent.Tokens; + } + catch (Exception) + { + s_logger.Error("Failed to parse the file with cached credentials"); + return new KeyTokenDict(); + } + } + + private void InitializeFileStorageIfNeeded() + { + if (_fileStorage != null) + return; + var fileStorage = new SFCredentialManagerFileStorage(_environmentOperations); + PrepareParentDirectory(fileStorage.JsonCacheDirectory); + PrepareSecureDirectory(fileStorage.JsonCacheDirectory); + _fileStorage = fileStorage; + } + + private void PrepareParentDirectory(string directory) + { + var parentDirectory = _directoryOperations.GetParentDirectoryInfo(directory); + if (!parentDirectory.Exists) + { + _directoryOperations.CreateDirectory(parentDirectory.FullName); + } + } + + private void PrepareSecureDirectory(string directory) + { + var unixDirectoryInfo = _unixOperations.GetDirectoryInfo(directory); + if (unixDirectoryInfo.Exists) + { + var userId = _unixOperations.GetCurrentUserId(); + if (!unixDirectoryInfo.IsSafeExactly(userId)) + { + SetSecureOwnershipAndPermissions(directory, userId); + } + } + else + { + var createResult = _unixOperations.CreateDirectoryWithPermissions(directory, FilePermissions.S_IRWXU); + if (createResult == -1) + { + throw new SecurityException($"Could not create directory: {directory}"); + } + } + } + + private void SetSecureOwnershipAndPermissions(string directory, long userId) + { + var groupId = _unixOperations.GetCurrentGroupId(); + var chownResult = _unixOperations.ChangeOwner(directory, (int) userId, (int) groupId); + if (chownResult == -1) + { + throw new SecurityException($"Could not set proper directory ownership for directory: {directory}"); + } + var chmodResult = _unixOperations.ChangePermissions(directory, FileAccessPermissions.UserReadWriteExecute); + if (chmodResult == -1) + { + throw new SecurityException($"Could not set proper directory permissions for directory: {directory}"); + } + } + + private bool AcquireLockWithRetries() => AcquireLock(5, TimeSpan.FromMilliseconds(50)); + + private bool AcquireLock(int numberOfAttempts, TimeSpan delayTime) + { + for (var i = 0; i < numberOfAttempts; i++) + { + if (AcquireLock()) + return true; + if (i + 1 < numberOfAttempts) + Thread.Sleep(delayTime); + } + return false; + } + + private bool AcquireLock() + { + if (!_directoryOperations.Exists(_fileStorage.JsonCacheDirectory)) + { + _directoryOperations.CreateDirectory(_fileStorage.JsonCacheDirectory); + } + var lockDirectoryInfo = _directoryOperations.GetDirectoryInfo(_fileStorage.JsonCacheLockPath); + if (lockDirectoryInfo.IsCreatedEarlierThanSeconds(CredentialCacheLockDurationSeconds, DateTime.UtcNow)) + { + s_logger.Warn($"File cache lock directory {_fileStorage.JsonCacheLockPath} created more than {CredentialCacheLockDurationSeconds} seconds ago. Removing the lock directory."); + ReleaseLock(); + } + else if (lockDirectoryInfo.Exists) + { + return false; + } + var result = _unixOperations.CreateDirectoryWithPermissions(_fileStorage.JsonCacheLockPath, CredentialCacheLockDirPermissions); + return result == 0; + } + + private void ReleaseLock() + { + _directoryOperations.Delete(_fileStorage.JsonCacheLockPath, false); + } + + internal void ValidateFilePermissions(UnixStream stream) + { + var allowedPermissions = new[] + { + FileAccessPermissions.UserRead | FileAccessPermissions.UserWrite + }; + if (stream.OwnerUser.UserId != _unixOperations.GetCurrentUserId()) + throw new SecurityException("Attempting to read or write a file not owned by the effective user of the current process"); + if (stream.OwnerGroup.GroupId != _unixOperations.GetCurrentGroupId()) + throw new SecurityException("Attempting to read or write a file not owned by the effective group of the current process"); + if (!(allowedPermissions.Any(a => stream.FileAccessPermissions == a))) + throw new SecurityException("Attempting to read or write a file with too broad permissions assigned"); + } + } +} diff --git a/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerFileStorage.cs b/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerFileStorage.cs new file mode 100644 index 000000000..335dbc83f --- /dev/null +++ b/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerFileStorage.cs @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +using System; +using System.IO; +using Snowflake.Data.Core.Tools; +using Snowflake.Data.Log; + +namespace Snowflake.Data.Core.CredentialManager.Infrastructure +{ + internal class SFCredentialManagerFileStorage + { + private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); + + internal const string CredentialCacheDirectoryEnvironmentName = "SF_TEMPORARY_CREDENTIAL_CACHE_DIR"; + + internal const string CommonCacheDirectoryEnvironmentName = "XDG_CACHE_HOME"; + + internal const string CommonCacheDirectoryName = ".cache"; + + internal const string CredentialCacheDirName = "snowflake"; + + internal const string CredentialCacheFileName = "credential_cache_v1.json"; + + internal const string CredentialCacheLockName = CredentialCacheFileName + ".lck"; + + public string JsonCacheDirectory { get; private set; } + + public string JsonCacheFilePath { get; private set; } + + public string JsonCacheLockPath { get; private set; } + + public SFCredentialManagerFileStorage(EnvironmentOperations environmentOperations) + { + var snowflakeEnvBasedDirectory = environmentOperations.GetEnvironmentVariable(CredentialCacheDirectoryEnvironmentName); + if (!string.IsNullOrEmpty(snowflakeEnvBasedDirectory)) + { + InitializeForDirectory(snowflakeEnvBasedDirectory); + return; + } + var commonCacheEnvBasedDirectory = environmentOperations.GetEnvironmentVariable(CommonCacheDirectoryEnvironmentName); + if (!string.IsNullOrEmpty(commonCacheEnvBasedDirectory)) + { + InitializeForDirectory(Path.Combine(commonCacheEnvBasedDirectory, CredentialCacheDirName)); + return; + } + var homeBasedDirectory = HomeDirectoryProvider.HomeDirectory(environmentOperations); + if (string.IsNullOrEmpty(homeBasedDirectory)) + { + throw new Exception("Unable to identify credential cache directory"); + } + InitializeForDirectory(Path.Combine(homeBasedDirectory, CommonCacheDirectoryName, CredentialCacheDirName)); + } + + private void InitializeForDirectory(string directory) + { + JsonCacheDirectory = directory; + JsonCacheFilePath = Path.Combine(directory, CredentialCacheFileName); + JsonCacheLockPath = Path.Combine(directory, CredentialCacheLockName); + s_logger.Info($"Setting the json credential cache path to {JsonCacheLockPath}"); + } + } +} diff --git a/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerInMemoryImpl.cs b/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerInMemoryImpl.cs new file mode 100644 index 000000000..ba8d4d9c4 --- /dev/null +++ b/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerInMemoryImpl.cs @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +using System.Collections.Generic; +using System.Security; +using System.Threading; +using Snowflake.Data.Client; +using Snowflake.Data.Core.Tools; +using Snowflake.Data.Log; + +namespace Snowflake.Data.Core.CredentialManager.Infrastructure +{ + internal class SFCredentialManagerInMemoryImpl : ISnowflakeCredentialManager + { + private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); + + private readonly ReaderWriterLockSlim _lock = new ReaderWriterLockSlim(); + + private Dictionary s_credentials = new Dictionary(); + + public static readonly SFCredentialManagerInMemoryImpl Instance = new SFCredentialManagerInMemoryImpl(); + + public string GetCredentials(string key) + { + s_logger.Debug($"Getting credentials from memory for key: {key}"); + bool found; + SecureString secureToken; + _lock.EnterReadLock(); + try + { + found = s_credentials.TryGetValue(key, out secureToken); + } + finally + { + _lock.ExitReadLock(); + } + if (found) + { + return SecureStringHelper.Decode(secureToken); + } + s_logger.Info("Unable to get credentials for the specified key"); + return ""; + } + + public void RemoveCredentials(string key) + { + s_logger.Debug($"Removing credentials from memory for key: {key}"); + _lock.EnterWriteLock(); + try + { + s_credentials.Remove(key); + } + finally + { + _lock.ExitWriteLock(); + } + } + + public void SaveCredentials(string key, string token) + { + s_logger.Debug($"Saving credentials into memory for key: {key}"); + var secureToken = SecureStringHelper.Encode(token); + _lock.EnterWriteLock(); + try + { + s_credentials[key] = secureToken; + } + finally + { + _lock.ExitWriteLock(); + } + } + } +} diff --git a/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerWindowsNativeImpl.cs b/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerWindowsNativeImpl.cs new file mode 100644 index 000000000..5e8819b9b --- /dev/null +++ b/Snowflake.Data/Core/CredentialManager/Infrastructure/SFCredentialManagerWindowsNativeImpl.cs @@ -0,0 +1,171 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +using Microsoft.Win32.SafeHandles; +using System; +using System.Runtime.InteropServices; +using System.Text; +using System.Threading; +using Snowflake.Data.Client; +using Snowflake.Data.Log; + +namespace Snowflake.Data.Core.CredentialManager.Infrastructure +{ + + internal class SFCredentialManagerWindowsNativeImpl : ISnowflakeCredentialManager + { + private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); + + private readonly ReaderWriterLockSlim _lock = new ReaderWriterLockSlim(); + + public static readonly SFCredentialManagerWindowsNativeImpl Instance = new SFCredentialManagerWindowsNativeImpl(); + + public string GetCredentials(string key) + { + s_logger.Debug($"Getting the credentials for key: {key}"); + bool success; + IntPtr nCredPtr; + _lock.EnterReadLock(); + try + { + success = CredRead(key, 1 /* Generic */, 0, out nCredPtr); + } + catch (Exception exception) + { + s_logger.Error($"Failed to get credentials", exception); + throw; + } + finally + { + _lock.ExitReadLock(); + } + + if (!success) + { + s_logger.Info($"Unable to get credentials for key: {key}"); + return ""; + } + + using (var critCred = new CriticalCredentialHandle(nCredPtr)) + { + var cred = critCred.GetCredential(); + return cred.CredentialBlob; + } + } + + public void RemoveCredentials(string key) + { + s_logger.Debug($"Removing the credentials for key: {key}"); + bool success; + _lock.EnterWriteLock(); + try + { + success = CredDelete(key, 1 /* Generic */, 0); + } + catch (Exception exception) + { + s_logger.Error($"Failed to remove credentials", exception); + throw; + } + finally + { + _lock.ExitWriteLock(); + } + if (!success) + { + s_logger.Info($"Unable to remove credentials because the specified key did not exist: {key}"); + } + } + + public void SaveCredentials(string key, string token) + { + s_logger.Debug($"Saving the credentials for key: {key}"); + byte[] byteArray = Encoding.Unicode.GetBytes(token); + Credential credential = new Credential(); + credential.AttributeCount = 0; + credential.Attributes = IntPtr.Zero; + credential.Comment = IntPtr.Zero; + credential.TargetAlias = IntPtr.Zero; + credential.Type = 1; // Generic + credential.Persist = 2; // Local Machine + credential.CredentialBlobSize = (uint)(byteArray == null ? 0 : byteArray.Length); + credential.TargetName = key; + credential.CredentialBlob = token; + credential.UserName = Environment.UserName; + + _lock.EnterWriteLock(); + try + { + CredWrite(ref credential, 0); + } + catch (Exception exception) + { + s_logger.Error($"Failed to save credentials", exception); + throw; + } + finally + { + _lock.ExitWriteLock(); + } + } + + [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)] + private struct Credential + { + public uint Flags; + public uint Type; + [MarshalAs(UnmanagedType.LPWStr)] + public string TargetName; + public IntPtr Comment; + public System.Runtime.InteropServices.ComTypes.FILETIME LastWritten; + public uint CredentialBlobSize; + [MarshalAs(UnmanagedType.LPWStr)] + public string CredentialBlob; + public uint Persist; + public uint AttributeCount; + public IntPtr Attributes; + public IntPtr TargetAlias; + [MarshalAs(UnmanagedType.LPWStr)] + public string UserName; + } + + sealed class CriticalCredentialHandle : CriticalHandleZeroOrMinusOneIsInvalid + { + public CriticalCredentialHandle(IntPtr handle) + { + SetHandle(handle); + } + + public Credential GetCredential() + { + var credential = (Credential)Marshal.PtrToStructure(handle, typeof(Credential)); + return credential; + } + + protected override bool ReleaseHandle() + { + if (IsInvalid) + { + return false; + } + + CredFree(handle); + SetHandleAsInvalid(); + return true; + } + } + + [DllImport("Advapi32.dll", EntryPoint = "CredDeleteW", CharSet = CharSet.Unicode, SetLastError = true)] + internal static extern bool CredDelete(string target, uint type, int reservedFlag); + + [DllImport("Advapi32.dll", EntryPoint = "CredReadW", CharSet = CharSet.Unicode, SetLastError = true)] + static extern bool CredRead(string target, uint type, int reservedFlag, out IntPtr credentialPtr); + + [DllImport("Advapi32.dll", EntryPoint = "CredWriteW", CharSet = CharSet.Unicode, SetLastError = true)] + static extern bool CredWrite([In] ref Credential userCredential, [In] uint flags); + + [DllImport("Advapi32.dll", EntryPoint = "CredFree", SetLastError = true)] + static extern bool CredFree([In] IntPtr cred); + } +} diff --git a/Snowflake.Data/Core/CredentialManager/TokenType.cs b/Snowflake.Data/Core/CredentialManager/TokenType.cs new file mode 100644 index 000000000..cdeb063d2 --- /dev/null +++ b/Snowflake.Data/Core/CredentialManager/TokenType.cs @@ -0,0 +1,14 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +namespace Snowflake.Data.Core.CredentialManager +{ + internal enum TokenType + { + [StringAttr(value = "ID_TOKEN")] + IdToken, + [StringAttr(value = "MFA_TOKEN")] + MFAToken + } +} diff --git a/Snowflake.Data/Core/ErrorMessages.resx b/Snowflake.Data/Core/ErrorMessages.resx index 3532f3394..664122e11 100755 --- a/Snowflake.Data/Core/ErrorMessages.resx +++ b/Snowflake.Data/Core/ErrorMessages.resx @@ -180,6 +180,9 @@ Snowflake type {0} is not supported for parameters. + + Invalid browser url "{0}" cannot be used for authentication. + Browser response timed out after {0} seconds. diff --git a/Snowflake.Data/Core/RestRequest.cs b/Snowflake.Data/Core/RestRequest.cs index 112743f77..de988895b 100644 --- a/Snowflake.Data/Core/RestRequest.cs +++ b/Snowflake.Data/Core/RestRequest.cs @@ -27,7 +27,7 @@ internal abstract class BaseRestRequest : IRestRequest internal static string REST_REQUEST_TIMEOUT_KEY = "TIMEOUT_PER_REST_REQUEST"; - // The default Rest timeout. Set to 120 seconds. + // The default Rest timeout. Set to 120 seconds. public static int DEFAULT_REST_RETRY_SECONDS_TIMEOUT = 120; internal Uri Url { get; set; } @@ -133,7 +133,7 @@ internal SFRestRequest() : base() public override string ToString() { - return String.Format("SFRestRequest {{url: {0}, request body: {1} }}", Url.ToString(), + return String.Format("SFRestRequest {{url: {0}, request body: {1} }}", Url.ToString(), jsonBody.ToString()); } @@ -259,12 +259,21 @@ class LoginRequestData [JsonProperty(PropertyName = "PROOF_KEY", NullValueHandling = NullValueHandling.Ignore)] internal string ProofKey { get; set; } + [JsonProperty(PropertyName = "EXT_AUTHN_DUO_METHOD", NullValueHandling = NullValueHandling.Ignore)] + internal string extAuthnDuoMethod { get; set; } + + [JsonProperty(PropertyName = "PASSCODE", NullValueHandling = NullValueHandling.Ignore)] + internal string passcode; + [JsonProperty(PropertyName = "SESSION_PARAMETERS", NullValueHandling = NullValueHandling.Ignore)] internal Dictionary SessionParameters { get; set; } + [JsonIgnore] + internal TimeSpan? HttpTimeout { get; set; } + public override string ToString() { - return String.Format("LoginRequestData {{ClientAppVersion: {0},\n AccountName: {1},\n loginName: {2},\n ClientEnv: {3},\n authenticator: {4} }}", + return String.Format("LoginRequestData {{ClientAppVersion: {0},\n AccountName: {1},\n loginName: {2},\n ClientEnv: {3},\n authenticator: {4} }}", clientAppVersion, accountName, loginName, clientEnv.ToString(), Authenticator); } } @@ -291,7 +300,7 @@ class LoginRequestClientEnv public override string ToString() { - return String.Format("{{ APPLICATION: {0}, OS_VERSION: {1}, NET_RUNTIME: {2}, NET_VERSION: {3}, INSECURE_MODE: {4} }}", + return String.Format("{{ APPLICATION: {0}, OS_VERSION: {1}, NET_RUNTIME: {2}, NET_VERSION: {3}, INSECURE_MODE: {4} }}", application, osVersion, netRuntime, netVersion, insecureMode); } } diff --git a/Snowflake.Data/Core/RestResponse.cs b/Snowflake.Data/Core/RestResponse.cs index b490ddcdc..4b827ef7f 100755 --- a/Snowflake.Data/Core/RestResponse.cs +++ b/Snowflake.Data/Core/RestResponse.cs @@ -17,9 +17,11 @@ abstract class BaseRestResponse [JsonProperty(PropertyName = "message")] internal String message { get; set; } + [JsonProperty(PropertyName = "code", NullValueHandling = NullValueHandling.Ignore)] internal int code { get; set; } + [JsonProperty(PropertyName = "success")] internal bool success { get; set; } @@ -92,6 +94,9 @@ internal class LoginResponseData [JsonProperty(PropertyName = "masterValidityInSeconds", NullValueHandling = NullValueHandling.Ignore)] internal int masterValidityInSeconds { get; set; } + + [JsonProperty(PropertyName = "mfaToken", NullValueHandling = NullValueHandling.Ignore)] + internal string mfaToken { get; set; } } internal class AuthenticatorResponseData diff --git a/Snowflake.Data/Core/SFError.cs b/Snowflake.Data/Core/SFError.cs old mode 100755 new mode 100644 index 44de969a1..b87dcd97f --- a/Snowflake.Data/Core/SFError.cs +++ b/Snowflake.Data/Core/SFError.cs @@ -3,6 +3,8 @@ */ using System; +using System.Collections.Generic; +using System.Linq; namespace Snowflake.Data.Core { @@ -92,7 +94,39 @@ public enum SFError STRUCTURED_TYPE_READ_ERROR, [SFErrorAttr(errorCode = 270062)] - STRUCTURED_TYPE_READ_DETAILED_ERROR + STRUCTURED_TYPE_READ_DETAILED_ERROR, + + [SFErrorAttr(errorCode = 390120)] + EXT_AUTHN_DENIED, + + [SFErrorAttr(errorCode = 390123)] + EXT_AUTHN_LOCKED, + + [SFErrorAttr(errorCode = 390126)] + EXT_AUTHN_TIMEOUT, + + [SFErrorAttr(errorCode = 390127)] + EXT_AUTHN_INVALID, + + [SFErrorAttr(errorCode = 390129)] + EXT_AUTHN_EXCEPTION, + } + + class SFMFATokenErrors + { + private static List InvalidMFATokenErrors = new List + { + SFError.EXT_AUTHN_DENIED, + SFError.EXT_AUTHN_LOCKED, + SFError.EXT_AUTHN_TIMEOUT, + SFError.EXT_AUTHN_INVALID, + SFError.EXT_AUTHN_EXCEPTION + }; + + public static bool IsInvalidMFATokenContinueError(int error) + { + return InvalidMFATokenErrors.Any(e => e.GetAttribute().errorCode == error); + } } class SFErrorAttr : Attribute diff --git a/Snowflake.Data/Core/Session/ConnectionCacheManager.cs b/Snowflake.Data/Core/Session/ConnectionCacheManager.cs index febecbbce..538221b09 100644 --- a/Snowflake.Data/Core/Session/ConnectionCacheManager.cs +++ b/Snowflake.Data/Core/Session/ConnectionCacheManager.cs @@ -11,9 +11,9 @@ namespace Snowflake.Data.Core.Session internal sealed class ConnectionCacheManager : IConnectionManager { private readonly SessionPool _sessionPool = SessionPool.CreateSessionCache(); - public SFSession GetSession(string connectionString, SecureString password) => _sessionPool.GetSession(connectionString, password); - public Task GetSessionAsync(string connectionString, SecureString password, CancellationToken cancellationToken) - => _sessionPool.GetSessionAsync(connectionString, password, cancellationToken); + public SFSession GetSession(string connectionString, SecureString password, SecureString passcode) => _sessionPool.GetSession(connectionString, password, passcode); + public Task GetSessionAsync(string connectionString, SecureString password, SecureString passcode, CancellationToken cancellationToken) + => _sessionPool.GetSessionAsync(connectionString, password, passcode, cancellationToken); public bool AddSession(SFSession session) => _sessionPool.AddSession(session, false); public void ReleaseBusySession(SFSession session) => _sessionPool.ReleaseBusySession(session); public void ClearAllPools() => _sessionPool.ClearSessions(); diff --git a/Snowflake.Data/Core/Session/ConnectionPoolManager.cs b/Snowflake.Data/Core/Session/ConnectionPoolManager.cs index 09bfa5821..6a0013bb0 100644 --- a/Snowflake.Data/Core/Session/ConnectionPoolManager.cs +++ b/Snowflake.Data/Core/Session/ConnectionPoolManager.cs @@ -29,16 +29,16 @@ internal ConnectionPoolManager() } } - public SFSession GetSession(string connectionString, SecureString password) + public SFSession GetSession(string connectionString, SecureString password, SecureString passcode) { s_logger.Debug($"ConnectionPoolManager::GetSession"); - return GetPool(connectionString, password).GetSession(); + return GetPool(connectionString, password).GetSession(passcode); } - public Task GetSessionAsync(string connectionString, SecureString password, CancellationToken cancellationToken) + public Task GetSessionAsync(string connectionString, SecureString password, SecureString passcode, CancellationToken cancellationToken) { s_logger.Debug($"ConnectionPoolManager::GetSessionAsync"); - return GetPool(connectionString, password).GetSessionAsync(cancellationToken); + return GetPool(connectionString, password).GetSessionAsync(passcode, cancellationToken); } public bool AddSession(SFSession session) diff --git a/Snowflake.Data/Core/Session/IConnectionManager.cs b/Snowflake.Data/Core/Session/IConnectionManager.cs index 01cfa3e8c..5d3885de4 100644 --- a/Snowflake.Data/Core/Session/IConnectionManager.cs +++ b/Snowflake.Data/Core/Session/IConnectionManager.cs @@ -10,8 +10,8 @@ namespace Snowflake.Data.Core.Session { internal interface IConnectionManager { - SFSession GetSession(string connectionString, SecureString password); - Task GetSessionAsync(string connectionString, SecureString password, CancellationToken cancellationToken); + SFSession GetSession(string connectionString, SecureString password, SecureString passcode = null); + Task GetSessionAsync(string connectionString, SecureString password, SecureString passcode, CancellationToken cancellationToken); bool AddSession(SFSession session); void ReleaseBusySession(SFSession session); void ClearAllPools(); diff --git a/Snowflake.Data/Core/Session/ISessionFactory.cs b/Snowflake.Data/Core/Session/ISessionFactory.cs index f9416de8d..fbc896fda 100644 --- a/Snowflake.Data/Core/Session/ISessionFactory.cs +++ b/Snowflake.Data/Core/Session/ISessionFactory.cs @@ -4,6 +4,6 @@ namespace Snowflake.Data.Core.Session { internal interface ISessionFactory { - SFSession NewSession(string connectionString, SecureString password); + SFSession NewSession(string connectionString, SecureString password, SecureString passcode); } } diff --git a/Snowflake.Data/Core/Session/SFSession.cs b/Snowflake.Data/Core/Session/SFSession.cs old mode 100755 new mode 100644 index b6a0ebf79..6b7aedd77 --- a/Snowflake.Data/Core/Session/SFSession.cs +++ b/Snowflake.Data/Core/Session/SFSession.cs @@ -1,4 +1,4 @@ -/* +/* * Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. */ @@ -14,6 +14,7 @@ using System.Threading.Tasks; using System.Net.Http; using System.Text.RegularExpressions; +using Snowflake.Data.Core.CredentialManager; using Snowflake.Data.Core.Session; using Snowflake.Data.Core.Tools; @@ -73,6 +74,8 @@ public class SFSession internal string ConnectionString { get; } internal SecureString Password { get; } + internal SecureString Passcode { get; } + private QueryContextCache _queryContextCache = new QueryContextCache(_defaultQueryContextCacheSize); private int _queryContextCacheSize = _defaultQueryContextCacheSize; @@ -98,6 +101,8 @@ public void SetPooling(bool isEnabled) internal String _queryTag; + internal SecureString _mfaToken; + internal void ProcessLoginResponse(LoginResponse authnResponse) { if (authnResponse.success) @@ -116,6 +121,12 @@ internal void ProcessLoginResponse(LoginResponse authnResponse) { logger.Debug("Query context cache disabled."); } + if (!string.IsNullOrEmpty(authnResponse.data.mfaToken)) + { + _mfaToken = SecureStringHelper.Encode(authnResponse.data.mfaToken); + var key = SnowflakeCredentialManagerFactory.GetSecureCredentialKey(properties[SFSessionProperty.HOST], properties[SFSessionProperty.USER], TokenType.MFAToken); + SnowflakeCredentialManagerFactory.GetCredentialManager().SaveCredentials(key, authnResponse.data.mfaToken); + } logger.Debug($"Session opened: {sessionId}"); _startTime = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(); } @@ -128,6 +139,14 @@ internal void ProcessLoginResponse(LoginResponse authnResponse) ""); logger.Error("Authentication failed", e); + if (SFMFATokenErrors.IsInvalidMFATokenContinueError(e.ErrorCode)) + { + logger.Info($"Unable to use cached MFA token is expired or invalid. Fails with the {e.Message}. ", e); + _mfaToken = null; + var mfaKey = SnowflakeCredentialManagerFactory.GetSecureCredentialKey(properties[SFSessionProperty.HOST], properties[SFSessionProperty.USER], TokenType.MFAToken); + SnowflakeCredentialManagerFactory.GetCredentialManager().RemoveCredentials(mfaKey); + } + throw e; } } @@ -158,19 +177,22 @@ internal Uri BuildLoginUrl() /// A string in the form of "key1=value1;key2=value2" internal SFSession( String connectionString, - SecureString password) : this(connectionString, password, EasyLoggingStarter.Instance) + SecureString password, + SecureString passcode = null) : this(connectionString, password, passcode, EasyLoggingStarter.Instance) { } internal SFSession( String connectionString, SecureString password, + SecureString passcode, EasyLoggingStarter easyLoggingStarter) { _easyLoggingStarter = easyLoggingStarter; ConnectionString = connectionString; Password = password; - properties = SFSessionProperties.ParseConnectionString(ConnectionString, Password); + Passcode = passcode; + properties = SFSessionProperties.ParseConnectionString(ConnectionString, Password, Passcode); _disableQueryContextCache = bool.Parse(properties[SFSessionProperty.DISABLEQUERYCONTEXTCACHE]); _disableConsoleLogin = bool.Parse(properties[SFSessionProperty.DISABLE_CONSOLE_LOGIN]); properties.TryGetValue(SFSessionProperty.USER, out _user); @@ -190,6 +212,12 @@ internal SFSession( _maxRetryCount = extractedProperties.maxHttpRetries; _maxRetryTimeout = extractedProperties.retryTimeout; _disableSamlUrlCheck = extractedProperties._disableSamlUrlCheck; + + if (properties.TryGetValue(SFSessionProperty.AUTHENTICATOR, out var _authenticatorType) && _authenticatorType == "username_password_mfa") + { + var mfaKey = SnowflakeCredentialManagerFactory.GetSecureCredentialKey(properties[SFSessionProperty.HOST], properties[SFSessionProperty.USER], TokenType.MFAToken); + _mfaToken = SecureStringHelper.Encode(SnowflakeCredentialManagerFactory.GetCredentialManager().GetCredentials(mfaKey)); + } } catch (SnowflakeDbException e) { @@ -221,7 +249,11 @@ private void ValidateApplicationName(SFSessionProperties properties) } } - internal SFSession(String connectionString, SecureString password, IMockRestRequester restRequester) : this(connectionString, password) + internal SFSession(String connectionString, SecureString password, IMockRestRequester restRequester) : this(connectionString, password, null, EasyLoggingStarter.Instance, restRequester) + { + } + + internal SFSession(String connectionString, SecureString password, SecureString passcode, EasyLoggingStarter easyLoggingStarter, IMockRestRequester restRequester) : this(connectionString, password, passcode, easyLoggingStarter) { // Inject the HttpClient to use with the Mock requester restRequester.setHttpClient(_HttpClient); @@ -428,6 +460,19 @@ internal SFRestRequest BuildTimeoutRestRequest(Uri uri, Object body) }; } + internal SFRestRequest BuildTimeoutRestRequest(Uri uri, Object body, TimeSpan httpTimeout) + { + return new SFRestRequest() + { + jsonBody = body, + Url = uri, + authorizationToken = SF_AUTHORIZATION_BASIC, + RestTimeout = connectionTimeout, + HttpTimeout = httpTimeout, + _isLogin = true + }; + } + internal void UpdateSessionParameterMap(List parameterList) { logger.Debug("Update parameter map"); diff --git a/Snowflake.Data/Core/Session/SFSessionParameter.cs b/Snowflake.Data/Core/Session/SFSessionParameter.cs index 97fdcec23..7d25c6e01 100755 --- a/Snowflake.Data/Core/Session/SFSessionParameter.cs +++ b/Snowflake.Data/Core/Session/SFSessionParameter.cs @@ -14,5 +14,6 @@ internal enum SFSessionParameter QUERY_CONTEXT_CACHE_SIZE, DATE_OUTPUT_FORMAT, TIME_OUTPUT_FORMAT, + CLIENT_REQUEST_MFA_TOKEN, } } diff --git a/Snowflake.Data/Core/Session/SFSessionProperty.cs b/Snowflake.Data/Core/Session/SFSessionProperty.cs index 07896ae14..5575f7c63 100644 --- a/Snowflake.Data/Core/Session/SFSessionProperty.cs +++ b/Snowflake.Data/Core/Session/SFSessionProperty.cs @@ -1,4 +1,4 @@ -/* +/* * Copyright (c) 2012-2021 Snowflake Computing Inc. All rights reserved. */ @@ -112,7 +112,11 @@ internal enum SFSessionProperty [SFSessionPropertyAttr(required = false, defaultValue = "true")] POOLINGENABLED, [SFSessionPropertyAttr(required = false, defaultValue = "false")] - DISABLE_SAML_URL_CHECK + DISABLE_SAML_URL_CHECK, + [SFSessionPropertyAttr(required = false, IsSecret = true)] + PASSCODE, + [SFSessionPropertyAttr(required = false, defaultValue = "false")] + PASSCODEINPASSWORD } class SFSessionPropertyAttr : Attribute @@ -181,7 +185,7 @@ public override int GetHashCode() return base.GetHashCode(); } - internal static SFSessionProperties ParseConnectionString(string connectionString, SecureString password) + internal static SFSessionProperties ParseConnectionString(string connectionString, SecureString password, SecureString passcode = null) { logger.Info("Start parsing connection string."); var builder = new DbConnectionStringBuilder(); @@ -257,7 +261,13 @@ internal static SFSessionProperties ParseConnectionString(string connectionStrin properties[SFSessionProperty.PASSWORD] = SecureStringHelper.Decode(password); } + if (passcode != null && passcode.Length > 0) + { + properties[SFSessionProperty.PASSCODE] = SecureStringHelper.Decode(passcode); + } + ValidateAuthenticator(properties); + ValidatePasscodeInPassword(properties); properties.IsPoolingEnabledValueProvided = properties.IsNonEmptyValueProvided(SFSessionProperty.POOLINGENABLED); CheckSessionProperties(properties); ValidateFileTransferMaxBytesInMemoryProperty(properties); @@ -303,7 +313,8 @@ private static void ValidateAuthenticator(SFSessionProperties properties) OktaAuthenticator.AUTH_NAME, OAuthAuthenticator.AUTH_NAME, KeyPairAuthenticator.AUTH_NAME, - ExternalBrowserAuthenticator.AUTH_NAME + ExternalBrowserAuthenticator.AUTH_NAME, + MFACacheAuthenticator.AuthName }; if (properties.TryGetValue(SFSessionProperty.AUTHENTICATOR, out var authenticator)) @@ -318,6 +329,23 @@ private static void ValidateAuthenticator(SFSessionProperties properties) } } + private static void ValidatePasscodeInPassword(SFSessionProperties properties) + { + if (properties.TryGetValue(SFSessionProperty.PASSCODEINPASSWORD, out var passCodeInPassword)) + { + if (!bool.TryParse(passCodeInPassword, out _)) + { + var errorMessage = $"Invalid value of {SFSessionProperty.PASSCODEINPASSWORD.ToString()} parameter"; + logger.Error(errorMessage); + throw new SnowflakeDbException( + new Exception(errorMessage), + SFError.INVALID_CONNECTION_PARAMETER_VALUE, + "", + SFSessionProperty.PASSCODEINPASSWORD.ToString()); + } + } + } + internal bool IsNonEmptyValueProvided(SFSessionProperty property) => TryGetValue(property, out var propertyValueStr) && !string.IsNullOrEmpty(propertyValueStr); diff --git a/Snowflake.Data/Core/Session/SessionFactory.cs b/Snowflake.Data/Core/Session/SessionFactory.cs index 2eb0ba6df..a1795ba10 100644 --- a/Snowflake.Data/Core/Session/SessionFactory.cs +++ b/Snowflake.Data/Core/Session/SessionFactory.cs @@ -4,9 +4,9 @@ namespace Snowflake.Data.Core.Session { internal class SessionFactory : ISessionFactory { - public SFSession NewSession(string connectionString, SecureString password) + public SFSession NewSession(string connectionString, SecureString password, SecureString passcode) { - return new SFSession(connectionString, password); + return new SFSession(connectionString, password, passcode, EasyLoggingStarter.Instance); } } } diff --git a/Snowflake.Data/Core/Session/SessionPool.cs b/Snowflake.Data/Core/Session/SessionPool.cs index de66c2240..abadd88e5 100644 --- a/Snowflake.Data/Core/Session/SessionPool.cs +++ b/Snowflake.Data/Core/Session/SessionPool.cs @@ -9,11 +9,13 @@ using System.Threading; using System.Threading.Tasks; using Snowflake.Data.Client; +using Snowflake.Data.Core.Authenticator; using Snowflake.Data.Core.Tools; using Snowflake.Data.Log; namespace Snowflake.Data.Core.Session { + sealed class SessionPool : IDisposable { private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); @@ -133,34 +135,71 @@ internal void ValidateSecurePassword(SecureString password) private string ExtractPassword(SecureString password) => password == null ? string.Empty : SecureStringHelper.Decode(password); - internal SFSession GetSession(string connStr, SecureString password) + internal SFSession GetSession(string connStr, SecureString password, SecureString passcode) { s_logger.Debug("SessionPool::GetSession" + PoolIdentification()); + var sessionProperties = SFSessionProperties.ParseConnectionString(connStr, password); + ValidateMinPoolSizeWithPasscode(sessionProperties, passcode); if (!GetPooling()) - return NewNonPoolingSession(connStr, password); - var sessionOrCreateTokens = GetIdleSession(connStr); + return NewNonPoolingSession(connStr, password, passcode); + var isMfaAuthentication = sessionProperties.TryGetValue(SFSessionProperty.AUTHENTICATOR, out var authenticator) && authenticator == MFACacheAuthenticator.AuthName; + var sessionOrCreateTokens = GetIdleSession(connStr, isMfaAuthentication ? 1 : int.MaxValue); if (sessionOrCreateTokens.Session != null) { _sessionPoolEventHandler.OnSessionProvided(this); } ScheduleNewIdleSessions(connStr, password, sessionOrCreateTokens.BackgroundSessionCreationTokens()); WarnAboutOverridenConfig(); - return sessionOrCreateTokens.Session ?? NewSession(connStr, password, sessionOrCreateTokens.SessionCreationToken()); + var session = sessionOrCreateTokens.Session ?? NewSession(connStr, password, passcode, sessionOrCreateTokens.SessionCreationToken()); + if (isMfaAuthentication) + { + ScheduleNewIdleSessions(connStr, password, RegisterSessionCreationsToEnsureMinPoolSize()); + } + return session; } - internal async Task GetSessionAsync(string connStr, SecureString password, CancellationToken cancellationToken) + private void ValidateMinPoolSizeWithPasscode(SFSessionProperties sessionProperties, SecureString passcode) + { + if (!GetPooling() || !IsMultiplePoolsVersion() || _poolConfig.MinPoolSize == 0) return; + var isUsingPasscode = (passcode != null && passcode.Length > 0) || (sessionProperties.IsNonEmptyValueProvided(SFSessionProperty.PASSCODE) || + (sessionProperties.TryGetValue(SFSessionProperty.PASSCODEINPASSWORD, out var passcodeInPasswordValue) && + bool.TryParse(passcodeInPasswordValue, out var isPasscodeinPassword) && isPasscodeinPassword)); + var isMfaAuthenticator = sessionProperties.TryGetValue(SFSessionProperty.AUTHENTICATOR, out var authenticator) && + authenticator == MFACacheAuthenticator.AuthName; + if(isUsingPasscode && !isMfaAuthenticator) + { + const string ErrorMessage = "Passcode with MinPoolSize feature of connection pool allowed only for username_password_mfa authentication"; + s_logger.Error(ErrorMessage + PoolIdentification()); + throw new SnowflakeDbException(SFError.INVALID_CONNECTION_STRING, ErrorMessage); + } + } + + internal async Task GetSessionAsync(string connStr, SecureString password, SecureString passcode, CancellationToken cancellationToken) { s_logger.Debug("SessionPool::GetSessionAsync" + PoolIdentification()); + var sessionProperties = SFSessionProperties.ParseConnectionString(connStr, password); + ValidateMinPoolSizeWithPasscode(sessionProperties, passcode); if (!GetPooling()) - return await NewNonPoolingSessionAsync(connStr, password, cancellationToken).ConfigureAwait(false); - var sessionOrCreateTokens = GetIdleSession(connStr); + return await NewNonPoolingSessionAsync(connStr, password, passcode, cancellationToken).ConfigureAwait(false); + var isMfaAuthentication = sessionProperties.TryGetValue(SFSessionProperty.AUTHENTICATOR, out var authenticator) && authenticator == MFACacheAuthenticator.AuthName; + var sessionOrCreateTokens = GetIdleSession(connStr, isMfaAuthentication ? 1 : int.MaxValue); + WarnAboutOverridenConfig(); + if (sessionOrCreateTokens.Session != null) { _sessionPoolEventHandler.OnSessionProvided(this); } ScheduleNewIdleSessions(connStr, password, sessionOrCreateTokens.BackgroundSessionCreationTokens()); WarnAboutOverridenConfig(); - return sessionOrCreateTokens.Session ?? await NewSessionAsync(connStr, password, sessionOrCreateTokens.SessionCreationToken(), cancellationToken).ConfigureAwait(false); + var session = sessionOrCreateTokens.Session ?? + await NewSessionAsync(connStr, password, passcode, sessionOrCreateTokens.SessionCreationToken(), cancellationToken) + .ConfigureAwait(false); + if (isMfaAuthentication) + { + ScheduleNewIdleSessions(connStr, password, RegisterSessionCreationsToEnsureMinPoolSize()); + } + return session; + } private void ScheduleNewIdleSessions(string connStr, SecureString password, List tokens) @@ -172,7 +211,7 @@ private void ScheduleNewIdleSession(string connStr, SecureString password, Sessi { Task.Run(() => { - var session = NewSession(connStr, password, token); + var session = NewSession(connStr, password, null, token); AddSession(session, false); // we don't want to ensure min pool size here because we could get into infinite recursion if expirationTimeout would be very low }); } @@ -187,17 +226,17 @@ private void WarnAboutOverridenConfig() internal bool IsConfigOverridden() => _configOverriden; - internal SFSession GetSession() => GetSession(ConnectionString, Password); + internal SFSession GetSession(SecureString passcode) => GetSession(ConnectionString, Password, passcode); - internal Task GetSessionAsync(CancellationToken cancellationToken) => - GetSessionAsync(ConnectionString, Password, cancellationToken); + internal Task GetSessionAsync(SecureString passcode, CancellationToken cancellationToken) => + GetSessionAsync(ConnectionString, Password, passcode, cancellationToken); internal void SetSessionPoolEventHandler(ISessionPoolEventHandler sessionPoolEventHandler) { _sessionPoolEventHandler = sessionPoolEventHandler; } - private SessionOrCreationTokens GetIdleSession(string connStr) + private SessionOrCreationTokens GetIdleSession(string connStr, int maxSessions) { s_logger.Debug("SessionPool::GetIdleSession" + PoolIdentification()); lock (_sessionPoolLock) @@ -215,7 +254,7 @@ private SessionOrCreationTokens GetIdleSession(string connStr) return new SessionOrCreationTokens(session); } s_logger.Debug("SessionPool::GetIdleSession - no thread was waiting for a session, but could not find any idle session available in the pool" + PoolIdentification()); - var sessionsCount = AllowedNumberOfNewSessionCreations(1); + var sessionsCount = AllowedNumberOfNewSessionCreations(1, maxSessions); if (sessionsCount > 0) { // there is no need to wait for a session since we can create new ones @@ -226,7 +265,7 @@ private SessionOrCreationTokens GetIdleSession(string connStr) return new SessionOrCreationTokens(WaitForSession(connStr)); } - private List RegisterSessionCreationsWhenReturningSessionToPool() + private List RegisterSessionCreationsToEnsureMinPoolSize() { var count = AllowedNumberOfNewSessionCreations(0); return RegisterSessionCreations(count); @@ -237,7 +276,7 @@ private List RegisterSessionCreations(int sessionsCount) = .Select(_ => _sessionCreationTokenCounter.NewToken()) .ToList(); - private int AllowedNumberOfNewSessionCreations(int atLeastCount) + private int AllowedNumberOfNewSessionCreations(int atLeastCount, int maxSessionsLimit = int.MaxValue) { // we are expecting to create atLeast 1 session in case of opening a connection (atLeastCount = 1) // but we have no expectations when closing a connection (atLeastCount = 0) @@ -252,7 +291,7 @@ private int AllowedNumberOfNewSessionCreations(int atLeastCount) { var maxSessionsToCreate = _poolConfig.MaxPoolSize - currentSize; var sessionsNeeded = Math.Max(_poolConfig.MinPoolSize - currentSize, atLeastCount); - var sessionsToCreate = Math.Min(sessionsNeeded, maxSessionsToCreate); + var sessionsToCreate = Math.Min(maxSessionsLimit, Math.Min(sessionsNeeded, maxSessionsToCreate)); s_logger.Debug($"SessionPool - allowed to create {sessionsToCreate} sessions, current pool size is {currentSize} out of {_poolConfig.MaxPoolSize}" + PoolIdentification()); return sessionsToCreate; } @@ -326,15 +365,15 @@ private SFSession ExtractIdleSession(string connStr) return null; } - private SFSession NewNonPoolingSession(String connectionString, SecureString password) => - NewSession(connectionString, password, _noPoolingSessionCreationTokenCounter.NewToken()); + private SFSession NewNonPoolingSession(String connectionString, SecureString password, SecureString passcode) => + NewSession(connectionString, password, passcode, _noPoolingSessionCreationTokenCounter.NewToken()); - private SFSession NewSession(String connectionString, SecureString password, SessionCreationToken sessionCreationToken) + private SFSession NewSession(String connectionString, SecureString password, SecureString passcode, SessionCreationToken sessionCreationToken) { s_logger.Debug("SessionPool::NewSession" + PoolIdentification()); try { - var session = s_sessionFactory.NewSession(connectionString, password); + var session = s_sessionFactory.NewSession(connectionString, password, passcode); session.Open(); s_logger.Debug("SessionPool::NewSession - opened" + PoolIdentification()); if (GetPooling() && !_underDestruction) @@ -374,13 +413,14 @@ private SFSession NewSession(String connectionString, SecureString password, Ses private Task NewNonPoolingSessionAsync( String connectionString, SecureString password, + SecureString passcode, CancellationToken cancellationToken) => - NewSessionAsync(connectionString, password, _noPoolingSessionCreationTokenCounter.NewToken(), cancellationToken); + NewSessionAsync(connectionString, password, passcode, _noPoolingSessionCreationTokenCounter.NewToken(), cancellationToken); - private Task NewSessionAsync(String connectionString, SecureString password, SessionCreationToken sessionCreationToken, CancellationToken cancellationToken) + private Task NewSessionAsync(String connectionString, SecureString password, SecureString passcode, SessionCreationToken sessionCreationToken, CancellationToken cancellationToken) { s_logger.Debug("SessionPool::NewSessionAsync" + PoolIdentification()); - var session = s_sessionFactory.NewSession(connectionString, password); + var session = s_sessionFactory.NewSession(connectionString, password, passcode); return session .OpenAsync(cancellationToken) .ContinueWith(previousTask => @@ -457,7 +497,7 @@ internal bool AddSession(SFSession session, bool ensureMinPoolSize) ReleaseBusySession(session); if (ensureMinPoolSize) { - ScheduleNewIdleSessions(ConnectionString, Password, RegisterSessionCreationsWhenReturningSessionToPool()); + ScheduleNewIdleSessions(ConnectionString, Password, RegisterSessionCreationsToEnsureMinPoolSize()); } return false; } @@ -478,7 +518,7 @@ private Tuple> ReturnSessionToPool(SFSession se { _busySessionsCounter.Decrease(); var sessionCreationTokens = ensureMinPoolSize - ? RegisterSessionCreationsWhenReturningSessionToPool() + ? RegisterSessionCreationsToEnsureMinPoolSize() : SessionOrCreationTokens.s_emptySessionCreationTokenList; var poolState = GetCurrentState(); s_logger.Debug($"Could not return session to pool {poolState}" + PoolIdentification()); @@ -493,7 +533,7 @@ private Tuple> ReturnSessionToPool(SFSession se if (session.IsExpired(_poolConfig.ExpirationTimeout, DateTimeOffset.UtcNow.ToUnixTimeMilliseconds())) // checking again because we could have spent some time waiting for a lock { var sessionCreationTokens = ensureMinPoolSize - ? RegisterSessionCreationsWhenReturningSessionToPool() + ? RegisterSessionCreationsToEnsureMinPoolSize() : SessionOrCreationTokens.s_emptySessionCreationTokenList; var poolState = GetCurrentState(); s_logger.Debug($"Could not return session to pool {poolState}" + PoolIdentification()); @@ -508,7 +548,7 @@ private Tuple> ReturnSessionToPool(SFSession se _idleSessions.Add(session); _waitingForIdleSessionQueue.OnResourceIncrease(); var sessionCreationTokensAfterReturningToPool = ensureMinPoolSize - ? RegisterSessionCreationsWhenReturningSessionToPool() + ? RegisterSessionCreationsToEnsureMinPoolSize() : SessionOrCreationTokens.s_emptySessionCreationTokenList; var poolStateAfterReturningToPool = GetCurrentState(); s_logger.Debug($"returned session with sid {session.sessionId} to pool {poolStateAfterReturningToPool}" + PoolIdentification()); diff --git a/Snowflake.Data/Core/TomlConnectionBuilder.cs b/Snowflake.Data/Core/TomlConnectionBuilder.cs index a8c2396b1..481628802 100644 --- a/Snowflake.Data/Core/TomlConnectionBuilder.cs +++ b/Snowflake.Data/Core/TomlConnectionBuilder.cs @@ -153,9 +153,10 @@ private string ResolveConnectionTomlFile() return tomlPath; } + internal static void ValidateFilePermissions(UnixStream stream) { - var allowedPermissions = new FileAccessPermissions[] + var allowedPermissions = new[] { FileAccessPermissions.UserRead | FileAccessPermissions.UserWrite, FileAccessPermissions.UserRead diff --git a/Snowflake.Data/Core/Tools/DirectoryInformation.cs b/Snowflake.Data/Core/Tools/DirectoryInformation.cs new file mode 100644 index 000000000..183ffa678 --- /dev/null +++ b/Snowflake.Data/Core/Tools/DirectoryInformation.cs @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +using System; +using System.IO; + +namespace Snowflake.Data.Core.Tools +{ + internal class DirectoryInformation + { + public bool Exists { get; private set; } + + public DateTime? CreationTimeUtc { get; private set; } + + public string FullName { get; private set; } + + public DirectoryInformation(DirectoryInfo directoryInfo) + { + Exists = directoryInfo.Exists; + CreationTimeUtc = directoryInfo.CreationTimeUtc; + FullName = directoryInfo.FullName; + } + + internal DirectoryInformation(bool exists, DateTime? creationTimeUtc) + { + Exists = exists; + CreationTimeUtc = creationTimeUtc; + } + + public bool IsCreatedEarlierThanSeconds(int seconds, DateTime utcNow) => + Exists && CreationTimeUtc?.AddSeconds(seconds) < utcNow; + } +} diff --git a/Snowflake.Data/Core/Tools/DirectoryOperations.cs b/Snowflake.Data/Core/Tools/DirectoryOperations.cs index 2d5d0424b..46254c85d 100644 --- a/Snowflake.Data/Core/Tools/DirectoryOperations.cs +++ b/Snowflake.Data/Core/Tools/DirectoryOperations.cs @@ -3,15 +3,42 @@ */ using System.IO; +using System.Runtime.InteropServices; namespace Snowflake.Data.Core.Tools { internal class DirectoryOperations { public static readonly DirectoryOperations Instance = new DirectoryOperations(); + private readonly UnixOperations _unixOperations; + + internal DirectoryOperations() : this(UnixOperations.Instance) + { + } + + internal DirectoryOperations(UnixOperations unixOperations) + { + _unixOperations = unixOperations; + } public virtual bool Exists(string path) => Directory.Exists(path); - + public virtual DirectoryInfo CreateDirectory(string path) => Directory.CreateDirectory(path); + + public virtual void Delete(string path, bool recursive) => Directory.Delete(path, recursive); + + public virtual DirectoryInformation GetDirectoryInfo(string path) => new DirectoryInformation(new DirectoryInfo(path)); + + public virtual DirectoryInformation GetParentDirectoryInfo(string path) => new DirectoryInformation(Directory.GetParent(path)); + + public virtual bool IsDirectorySafe(string path) + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + return true; + } + var unixInfo = _unixOperations.GetDirectoryInfo(path); + return unixInfo.IsSafe(_unixOperations.GetCurrentUserId()); + } } } diff --git a/Snowflake.Data/Core/Tools/DirectoryUnixInformation.cs b/Snowflake.Data/Core/Tools/DirectoryUnixInformation.cs new file mode 100644 index 000000000..d0fb960de --- /dev/null +++ b/Snowflake.Data/Core/Tools/DirectoryUnixInformation.cs @@ -0,0 +1,77 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +using Mono.Unix; +using Snowflake.Data.Log; + +namespace Snowflake.Data.Core.Tools +{ + internal class DirectoryUnixInformation + { + private const FileAccessPermissions SafePermissions = FileAccessPermissions.UserReadWriteExecute; + private const FileAccessPermissions NotSafePermissions = FileAccessPermissions.AllPermissions & ~SafePermissions; + private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); + + public string FullName { get; private set; } + public bool Exists { get; private set; } + public FileAccessPermissions Permissions { get; private set; } + public long Owner { get; private set; } + + public DirectoryUnixInformation(UnixDirectoryInfo directoryInfo) + { + FullName = directoryInfo.FullName; + Exists = directoryInfo.Exists; + if (Exists) + { + Permissions = directoryInfo.FileAccessPermissions; + Owner = directoryInfo.OwnerUserId; + } + } + + internal DirectoryUnixInformation(string fullName, bool exists, FileAccessPermissions permissions, long owner) + { + FullName = fullName; + Exists = exists; + Permissions = permissions; + Owner = owner; + } + + public bool IsSafe(long userId) + { + if (HasAnyOfPermissions(NotSafePermissions)) + { + s_logger.Warn($"Directory '{FullName}' permissions are too broad. It could be potentially accessed by group or others."); + return false; + } + if (!IsOwnedBy(userId)) + { + s_logger.Warn($"Directory '{FullName}' is not owned by the current user."); + return false; + } + return true; + } + + public bool IsSafeExactly(long userId) + { + if (SafePermissions != Permissions) + { + s_logger.Warn($"Directory '{FullName}' permissions are different than 700."); + return false; + } + if (!IsOwnedBy(userId)) + { + s_logger.Warn($"Directory '{FullName}' is not owned by the current user."); + return false; + } + return true; + } + + + private bool HasAnyOfPermissions(FileAccessPermissions permissions) => (permissions & Permissions) != 0; + + private bool IsOwnedBy(long userId) => Owner == userId; + + + } +} diff --git a/Snowflake.Data/Core/Tools/FileOperations.cs b/Snowflake.Data/Core/Tools/FileOperations.cs index 577bd54ee..1324423f2 100644 --- a/Snowflake.Data/Core/Tools/FileOperations.cs +++ b/Snowflake.Data/Core/Tools/FileOperations.cs @@ -20,6 +20,18 @@ public virtual bool Exists(string path) return File.Exists(path); } + public virtual void Write(string path, string content, Action validator = null) + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + File.WriteAllText(path, content); + } + else + { + _unixOperations.WriteAllText(path, content, validator); + } + } + public virtual string ReadAllText(string path) { return ReadAllText(path, null); diff --git a/Snowflake.Data/Core/Tools/StringUtils.cs b/Snowflake.Data/Core/Tools/StringUtils.cs new file mode 100644 index 000000000..3e5c45767 --- /dev/null +++ b/Snowflake.Data/Core/Tools/StringUtils.cs @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +using System; +using System.Security.Cryptography; + +namespace Snowflake.Data.Core.Tools +{ + public static class StringUtils + { + internal static string ToSha256Hash(this string text) + { + if (string.IsNullOrEmpty(text)) + return string.Empty; + + using (var sha256Encoder = SHA256.Create()) + { + var sha256Hash = sha256Encoder.ComputeHash(System.Text.Encoding.UTF8.GetBytes(text)); + return BitConverter.ToString(sha256Hash).Replace("-", string.Empty); + } + } + } +} diff --git a/Snowflake.Data/Core/Tools/UnixOperations.cs b/Snowflake.Data/Core/Tools/UnixOperations.cs index 655b708ea..5133255da 100644 --- a/Snowflake.Data/Core/Tools/UnixOperations.cs +++ b/Snowflake.Data/Core/Tools/UnixOperations.cs @@ -4,6 +4,7 @@ using System; using System.IO; +using System.Linq; using System.Security; using System.Text; using Mono.Unix; @@ -11,28 +12,50 @@ namespace Snowflake.Data.Core.Tools { + internal class UnixOperations { public static readonly UnixOperations Instance = new UnixOperations(); + public virtual int CreateFileWithPermissions(string path, FilePermissions permissions) + { + return Syscall.creat(path, permissions); + } + public virtual int CreateDirectoryWithPermissions(string path, FilePermissions permissions) { return Syscall.mkdir(path, permissions); } + public virtual FileAccessPermissions GetFilePermissions(string path) + { + var fileInfo = new UnixFileInfo(path); + return fileInfo.FileAccessPermissions; + } + public virtual FileAccessPermissions GetDirPermissions(string path) { var dirInfo = new UnixDirectoryInfo(path); return dirInfo.FileAccessPermissions; } + public virtual DirectoryUnixInformation GetDirectoryInfo(string path) + { + var dirInfo = new UnixDirectoryInfo(path); + return new DirectoryUnixInformation(dirInfo); + } + + public virtual long ChangeOwner(string path, int userId, int groupId) => Syscall.chown(path, userId, groupId); + + public virtual long ChangePermissions(string path, FileAccessPermissions permissions) => Syscall.chmod(path, (FilePermissions) permissions); + public virtual bool CheckFileHasAnyOfPermissions(string path, FileAccessPermissions permissions) { var fileInfo = new UnixFileInfo(path); return (permissions & fileInfo.FileAccessPermissions) != 0; } - public string ReadAllText(string path, Action validator) + public string ReadAllText(string path, Action validator) { var fileInfo = new UnixFileInfo(path: path); @@ -45,5 +68,29 @@ public string ReadAllText(string path, Action validator) } } } + + public void WriteAllText(string path, string content, Action validator) + { + var fileInfo = new UnixFileInfo(path: path); + + using (var handle = fileInfo.Open(FileMode.Create, FileAccess.ReadWrite, FilePermissions.S_IWUSR | FilePermissions.S_IRUSR)) + { + validator?.Invoke(handle); + using (var streamWriter = new StreamWriter(handle, Encoding.UTF8)) + { + streamWriter.Write(content); + } + } + } + + public virtual long GetCurrentUserId() + { + return Syscall.getuid(); + } + + public virtual long GetCurrentGroupId() + { + return Syscall.getgid(); + } } } diff --git a/Snowflake.Data/Logger/SecretDetector.cs b/Snowflake.Data/Logger/SecretDetector.cs index 59cd810d6..09c5981cf 100644 --- a/Snowflake.Data/Logger/SecretDetector.cs +++ b/Snowflake.Data/Logger/SecretDetector.cs @@ -92,7 +92,7 @@ private static string MaskCustomPatterns(string text) private const string ConnectionTokenPattern = @"(token|assertion content)(['""\s:=]+)([a-z0-9=/_\-+:]{8,})"; private const string TokenPropertyPattern = @"(token)(\s*=)(.*)"; private const string PasswordPattern = @"(password|passcode|pwd|proxypassword|private_key_pwd)(['""\s:=]+)([a-z0-9!""#$%&'\()*+,-./:;<=>?@\[\]\^_`{|}~]{6,})"; - private const string PasswordPropertyPattern = @"(password|proxypassword|private_key_pwd)(\s*=)(.*)"; + private const string PasswordPropertyPattern = @"(password|passcode|proxypassword|private_key_pwd)(\s*=)(.*)"; private static readonly Func[] s_maskFunctions = { MaskAWSServerSide, diff --git a/Snowflake.Data/Snowflake.Data.csproj b/Snowflake.Data/Snowflake.Data.csproj index f17124419..caac7ebed 100644 --- a/Snowflake.Data/Snowflake.Data.csproj +++ b/Snowflake.Data/Snowflake.Data.csproj @@ -1,4 +1,4 @@ - + netstandard2.0 Snowflake.Data diff --git a/doc/Connecting.md b/doc/Connecting.md index 0999d6a58..2a2b4cc6c 100644 --- a/doc/Connecting.md +++ b/doc/Connecting.md @@ -50,6 +50,8 @@ The following table lists all valid connection properties: | EXPIRATIONTIMEOUT | No | Timeout for using each connection. Connections which last more than specified timeout are considered to be expired and are being removed from the pool. The default is 1 hour. Usage of units possible and allowed are: e. g. `360000ms` (milliseconds), `3600s` (seconds), `60m` (minutes) where seconds are default for a skipped postfix. Special values: `0` - immediate expiration of the connection just after its creation. Expiration timeout cannot be set to infinity. | | POOLINGENABLED | No | Boolean flag indicating if the connection should be a part of a pool. The default value is `true`. | | DISABLE_SAML_URL_CHECK | No | Specifies whether to check if the saml postback url matches the host url from the connection string. The default value is `false`. | +| PASSCODE | No | Passcode from your 2FA application to be used in Multi Factor Authentication. | +| PASSCODEINPASSWORD | No | Boolean flag indicating if MFA passcode is added to the password. |