Main Page | Namespace List | Class Hierarchy | Alphabetical List | Class List | File List | Namespace Members | Class Members | File Members

datatest.cpp

00001 #include "factory.h" 00002 #include "integer.h" 00003 #include "filters.h" 00004 #include "hex.h" 00005 #include "randpool.h" 00006 #include "files.h" 00007 #include "trunhash.h" 00008 #include <iostream> 00009 #include <memory> 00010 00011 USING_NAMESPACE(CryptoPP) 00012 USING_NAMESPACE(std) 00013 00014 RandomPool & GlobalRNG(); 00015 void RegisterFactories(); 00016 00017 typedef std::map<std::string, std::string> TestData; 00018 00019 class TestFailure : public Exception 00020 { 00021 public: 00022 TestFailure() : Exception(OTHER_ERROR, "Validation test failed") {} 00023 }; 00024 00025 static const TestData *s_currentTestData = NULL; 00026 00027 static void OutputTestData(const TestData &v) 00028 { 00029 for (TestData::const_iterator i = v.begin(); i != v.end(); ++i) 00030 { 00031 cerr << i->first << ": " << i->second << endl; 00032 } 00033 } 00034 00035 static void SignalTestFailure() 00036 { 00037 OutputTestData(*s_currentTestData); 00038 throw TestFailure(); 00039 } 00040 00041 static void SignalTestError() 00042 { 00043 OutputTestData(*s_currentTestData); 00044 throw Exception(Exception::OTHER_ERROR, "Unexpected error during validation test"); 00045 } 00046 00047 class TestDataNameValuePairs : public NameValuePairs 00048 { 00049 public: 00050 TestDataNameValuePairs(const TestData &data) : m_data(data) {} 00051 00052 virtual bool GetVoidValue(const char *name, const std::type_info &valueType, void *pValue) const 00053 { 00054 TestData::const_iterator i = m_data.find(name); 00055 if (i == m_data.end()) 00056 return false; 00057 00058 const std::string &value = i->second; 00059 00060 if (valueType == typeid(int)) 00061 *reinterpret_cast<int *>(pValue) = atoi(value.c_str()); 00062 else if (valueType == typeid(Integer)) 00063 *reinterpret_cast<Integer *>(pValue) = Integer((std::string(value) + "h").c_str()); 00064 else if (valueType == typeid(ConstByteArrayParameter)) 00065 { 00066 m_temp.resize(0); 00067 StringSource(value, true, new HexDecoder(new StringSink(m_temp))); 00068 reinterpret_cast<ConstByteArrayParameter *>(pValue)->Assign((const byte *)m_temp.data(), m_temp.size(), true); 00069 } 00070 else if (valueType == typeid(const byte *)) 00071 { 00072 m_temp.resize(0); 00073 StringSource(value, true, new HexDecoder(new StringSink(m_temp))); 00074 *reinterpret_cast<const byte * *>(pValue) = (const byte *)m_temp.data(); 00075 } 00076 else 00077 throw ValueTypeMismatch(name, typeid(std::string), valueType); 00078 00079 return true; 00080 } 00081 00082 private: 00083 const TestData &m_data; 00084 mutable std::string m_temp; 00085 }; 00086 00087 const std::string & GetRequiredDatum(const TestData &data, const char *name) 00088 { 00089 TestData::const_iterator i = data.find(name); 00090 if (i == data.end()) 00091 SignalTestError(); 00092 return i->second; 00093 } 00094 00095 void PutDecodedDatumInto(const TestData &data, const char *name, BufferedTransformation &target) 00096 { 00097 std::string s1 = GetRequiredDatum(data, name), s2; 00098 00099 int repeat = 1; 00100 if (s1[0] == 'r') 00101 { 00102 repeat = atoi(s1.c_str()+1); 00103 s1 = s1.substr(s1.find(' ')+1); 00104 } 00105 00106 if (s1[0] == '\"') 00107 s2 = s1.substr(1, s1.find('\"', 1)-1); 00108 else if (s1.substr(0, 2) == "0x") 00109 StringSource(s1.substr(2), true, new HexDecoder(new StringSink(s2))); 00110 else 00111 StringSource(s1, true, new HexDecoder(new StringSink(s2))); 00112 00113 while (repeat--) 00114 target.Put((const byte *)s2.data(), s2.size()); 00115 } 00116 00117 std::string GetDecodedDatum(const TestData &data, const char *name) 00118 { 00119 std::string s; 00120 PutDecodedDatumInto(data, name, StringSink(s).Ref()); 00121 return s; 00122 } 00123 00124 void TestKeyPairValidAndConsistent(CryptoMaterial &pub, const CryptoMaterial &priv) 00125 { 00126 if (!pub.Validate(GlobalRNG(), 3)) 00127 SignalTestFailure(); 00128 if (!priv.Validate(GlobalRNG(), 3)) 00129 SignalTestFailure(); 00130 00131 /* EqualityComparisonFilter comparison; 00132 pub.Save(ChannelSwitch(comparison, "0")); 00133 pub.AssignFrom(priv); 00134 pub.Save(ChannelSwitch(comparison, "1")); 00135 comparison.ChannelMessageSeriesEnd("0"); 00136 comparison.ChannelMessageSeriesEnd("1"); 00137 */ 00138 } 00139 00140 void TestSignatureScheme(TestData &v) 00141 { 00142 std::string name = GetRequiredDatum(v, "Name"); 00143 std::string test = GetRequiredDatum(v, "Test"); 00144 00145 std::auto_ptr<PK_Signer> signer(ObjectFactoryRegistry<PK_Signer>::Registry().CreateObject(name.c_str())); 00146 std::auto_ptr<PK_Verifier> verifier(ObjectFactoryRegistry<PK_Verifier>::Registry().CreateObject(name.c_str())); 00147 00148 TestDataNameValuePairs pairs(v); 00149 std::string keyFormat = GetRequiredDatum(v, "KeyFormat"); 00150 00151 if (keyFormat == "DER") 00152 verifier->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PublicKey")).Ref()); 00153 else if (keyFormat == "Component") 00154 verifier->AccessMaterial().AssignFrom(pairs); 00155 00156 if (test == "Verify" || test == "NotVerify") 00157 { 00158 VerifierFilter verifierFilter(*verifier, NULL, VerifierFilter::SIGNATURE_AT_BEGIN); 00159 PutDecodedDatumInto(v, "Signature", verifierFilter); 00160 PutDecodedDatumInto(v, "Message", verifierFilter); 00161 verifierFilter.MessageEnd(); 00162 if (verifierFilter.GetLastResult() == (test == "NotVerify")) 00163 SignalTestFailure(); 00164 } 00165 else if (test == "PublicKeyValid") 00166 { 00167 if (!verifier->GetMaterial().Validate(GlobalRNG(), 3)) 00168 SignalTestFailure(); 00169 } 00170 else 00171 goto privateKeyTests; 00172 00173 return; 00174 00175 privateKeyTests: 00176 if (keyFormat == "DER") 00177 signer->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PrivateKey")).Ref()); 00178 else if (keyFormat == "Component") 00179 signer->AccessMaterial().AssignFrom(pairs); 00180 00181 if (test == "KeyPairValidAndConsistent") 00182 { 00183 TestKeyPairValidAndConsistent(verifier->AccessMaterial(), signer->GetMaterial()); 00184 } 00185 else if (test == "Sign") 00186 { 00187 SignerFilter f(GlobalRNG(), *signer, new HexEncoder(new FileSink(cout))); 00188 StringSource ss(GetDecodedDatum(v, "Message"), true, new Redirector(f)); 00189 SignalTestFailure(); 00190 } 00191 else if (test == "DeterministicSign") 00192 { 00193 SignalTestError(); 00194 assert(false); // TODO: implement 00195 } 00196 else if (test == "RandomSign") 00197 { 00198 SignalTestError(); 00199 assert(false); // TODO: implement 00200 } 00201 else if (test == "GenerateKey") 00202 { 00203 SignalTestError(); 00204 assert(false); 00205 } 00206 else 00207 { 00208 SignalTestError(); 00209 assert(false); 00210 } 00211 } 00212 00213 void TestAsymmetricCipher(TestData &v) 00214 { 00215 std::string name = GetRequiredDatum(v, "Name"); 00216 std::string test = GetRequiredDatum(v, "Test"); 00217 00218 std::auto_ptr<PK_Encryptor> encryptor(ObjectFactoryRegistry<PK_Encryptor>::Registry().CreateObject(name.c_str())); 00219 std::auto_ptr<PK_Decryptor> decryptor(ObjectFactoryRegistry<PK_Decryptor>::Registry().CreateObject(name.c_str())); 00220 00221 std::string keyFormat = GetRequiredDatum(v, "KeyFormat"); 00222 00223 if (keyFormat == "DER") 00224 { 00225 decryptor->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PrivateKey")).Ref()); 00226 encryptor->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PublicKey")).Ref()); 00227 } 00228 else if (keyFormat == "Component") 00229 { 00230 TestDataNameValuePairs pairs(v); 00231 decryptor->AccessMaterial().AssignFrom(pairs); 00232 encryptor->AccessMaterial().AssignFrom(pairs); 00233 } 00234 00235 if (test == "DecryptMatch") 00236 { 00237 std::string decrypted, expected = GetDecodedDatum(v, "Plaintext"); 00238 StringSource ss(GetDecodedDatum(v, "Ciphertext"), true, new PK_DecryptorFilter(GlobalRNG(), *decryptor, new StringSink(decrypted))); 00239 if (decrypted != expected) 00240 SignalTestFailure(); 00241 } 00242 else if (test == "KeyPairValidAndConsistent") 00243 { 00244 TestKeyPairValidAndConsistent(encryptor->AccessMaterial(), decryptor->GetMaterial()); 00245 } 00246 else 00247 { 00248 SignalTestError(); 00249 assert(false); 00250 } 00251 } 00252 00253 void TestSymmetricCipher(TestData &v) 00254 { 00255 std::string name = GetRequiredDatum(v, "Name"); 00256 std::string test = GetRequiredDatum(v, "Test"); 00257 00258 std::string key = GetDecodedDatum(v, "Key"); 00259 std::string ciphertext = GetDecodedDatum(v, "Ciphertext"); 00260 std::string plaintext = GetDecodedDatum(v, "Plaintext"); 00261 00262 TestDataNameValuePairs pairs(v); 00263 00264 if (test == "Encrypt") 00265 { 00266 std::auto_ptr<SymmetricCipher> encryptor(ObjectFactoryRegistry<SymmetricCipher, ENCRYPTION>::Registry().CreateObject(name.c_str())); 00267 encryptor->SetKey((const byte *)key.data(), key.size(), pairs); 00268 std::string encrypted; 00269 StringSource ss(plaintext, true, new StreamTransformationFilter(*encryptor, new StringSink(encrypted), StreamTransformationFilter::NO_PADDING)); 00270 if (encrypted != ciphertext) 00271 SignalTestFailure(); 00272 } 00273 else if (test == "Decrypt") 00274 { 00275 std::auto_ptr<SymmetricCipher> decryptor(ObjectFactoryRegistry<SymmetricCipher, DECRYPTION>::Registry().CreateObject(name.c_str())); 00276 decryptor->SetKey((const byte *)key.data(), key.size(), pairs); 00277 std::string decrypted; 00278 StringSource ss(ciphertext, true, new StreamTransformationFilter(*decryptor, new StringSink(decrypted), StreamTransformationFilter::NO_PADDING)); 00279 if (decrypted != plaintext) 00280 SignalTestFailure(); 00281 } 00282 else 00283 { 00284 SignalTestError(); 00285 assert(false); 00286 } 00287 } 00288 00289 void TestDigestOrMAC(TestData &v, bool testDigest) 00290 { 00291 std::string name = GetRequiredDatum(v, "Name"); 00292 std::string test = GetRequiredDatum(v, "Test"); 00293 00294 member_ptr<MessageAuthenticationCode> mac; 00295 member_ptr<HashTransformation> hash; 00296 HashTransformation *pHash = NULL; 00297 00298 if (testDigest) 00299 { 00300 hash.reset(ObjectFactoryRegistry<HashTransformation>::Registry().CreateObject(name.c_str())); 00301 pHash = hash.get(); 00302 } 00303 else 00304 { 00305 mac.reset(ObjectFactoryRegistry<MessageAuthenticationCode>::Registry().CreateObject(name.c_str())); 00306 pHash = mac.get(); 00307 std::string key = GetDecodedDatum(v, "Key"); 00308 mac->SetKey((const byte *)key.c_str(), key.size()); 00309 } 00310 00311 if (test == "Verify" || test == "VerifyTruncated" || test == "NotVerify") 00312 { 00313 int digestSize = pHash->DigestSize(); 00314 if (test == "VerifyTruncated") 00315 digestSize = atoi(GetRequiredDatum(v, "TruncatedSize").c_str()); 00316 TruncatedHashModule thash(*pHash, digestSize); 00317 HashVerificationFilter verifierFilter(thash, NULL, HashVerificationFilter::HASH_AT_BEGIN); 00318 PutDecodedDatumInto(v, "Digest", verifierFilter); 00319 PutDecodedDatumInto(v, "Message", verifierFilter); 00320 verifierFilter.MessageEnd(); 00321 if (verifierFilter.GetLastResult() == (test == "NotVerify")) 00322 SignalTestFailure(); 00323 } 00324 else 00325 { 00326 SignalTestError(); 00327 assert(false); 00328 } 00329 } 00330 00331 bool GetField(std::istream &is, std::string &name, std::string &value) 00332 { 00333 name.resize(0); // GCC workaround: 2.95.3 doesn't have clear() 00334 is >> name; 00335 if (name.empty()) 00336 return false; 00337 00338 if (name[name.size()-1] != ':') 00339 SignalTestError(); 00340 name.erase(name.size()-1); 00341 00342 while (is.peek() == ' ') 00343 is.ignore(1); 00344 00345 // VC60 workaround: getline bug 00346 char buffer[128]; 00347 value.resize(0); // GCC workaround: 2.95.3 doesn't have clear() 00348 bool continueLine; 00349 00350 do 00351 { 00352 do 00353 { 00354 is.get(buffer, sizeof(buffer)); 00355 value += buffer; 00356 } 00357 while (buffer[0] != 0); 00358 is.clear(); 00359 is.ignore(); 00360 00361 if (value[value.size()-1] == '\\') 00362 { 00363 value.resize(value.size()-1); 00364 continueLine = true; 00365 } 00366 else 00367 continueLine = false; 00368 00369 std::string::size_type i = value.find('#'); 00370 if (i != std::string::npos) 00371 value.erase(i); 00372 } 00373 while (continueLine); 00374 00375 return true; 00376 } 00377 00378 void OutputPair(const NameValuePairs &v, const char *name) 00379 { 00380 Integer x; 00381 bool b = v.GetValue(name, x); 00382 assert(b); 00383 cout << name << ": \\\n "; 00384 x.Encode(HexEncoder(new FileSink(cout), false, 64, "\\\n ").Ref(), x.MinEncodedSize()); 00385 cout << endl; 00386 } 00387 00388 void OutputNameValuePairs(const NameValuePairs &v) 00389 { 00390 std::string names = v.GetValueNames(); 00391 string::size_type i = 0; 00392 while (i < names.size()) 00393 { 00394 string::size_type j = names.find_first_of (';', i); 00395 00396 if (j == string::npos) 00397 return; 00398 else 00399 { 00400 std::string name = names.substr(i, j-i); 00401 if (name.find(':') == string::npos) 00402 OutputPair(v, name.c_str()); 00403 } 00404 00405 i = j + 1; 00406 } 00407 } 00408 00409 void TestDataFile(const std::string &filename, unsigned int &totalTests, unsigned int &failedTests) 00410 { 00411 std::ifstream file(filename.c_str()); 00412 TestData v; 00413 s_currentTestData = &v; 00414 std::string name, value, lastAlgName; 00415 00416 while (file) 00417 { 00418 while (file.peek() == '#') 00419 file.ignore(INT_MAX, '\n'); 00420 00421 if (file.peek() == '\n') 00422 v.clear(); 00423 00424 if (!GetField(file, name, value)) 00425 break; 00426 v[name] = value; 00427 00428 if (name == "Test") 00429 { 00430 bool failed = true; 00431 std::string algType = GetRequiredDatum(v, "AlgorithmType"); 00432 00433 if (lastAlgName != GetRequiredDatum(v, "Name")) 00434 { 00435 lastAlgName = GetRequiredDatum(v, "Name"); 00436 cout << "\nTesting " << algType.c_str() << " algorithm " << lastAlgName.c_str() << ".\n"; 00437 } 00438 00439 try 00440 { 00441 if (algType == "Signature") 00442 TestSignatureScheme(v); 00443 else if (algType == "SymmetricCipher") 00444 TestSymmetricCipher(v); 00445 else if (algType == "AsymmetricCipher") 00446 TestAsymmetricCipher(v); 00447 else if (algType == "MessageDigest") 00448 TestDigestOrMAC(v, true); 00449 else if (algType == "MAC") 00450 TestDigestOrMAC(v, false); 00451 else if (algType == "FileList") 00452 TestDataFile(GetRequiredDatum(v, "Test"), totalTests, failedTests); 00453 else 00454 SignalTestError(); 00455 failed = false; 00456 } 00457 catch (TestFailure &) 00458 { 00459 cout << "\nTest failed.\n"; 00460 } 00461 catch (CryptoPP::Exception &e) 00462 { 00463 cout << "\nCryptoPP::Exception caught: " << e.what() << endl; 00464 } 00465 catch (std::exception &e) 00466 { 00467 cout << "\nstd::exception caught: " << e.what() << endl; 00468 } 00469 00470 if (failed) 00471 { 00472 cout << "Skipping to next test.\n"; 00473 failedTests++; 00474 } 00475 else 00476 cout << "." << flush; 00477 00478 totalTests++; 00479 } 00480 } 00481 } 00482 00483 bool RunTestDataFile(const char *filename) 00484 { 00485 RegisterFactories(); 00486 unsigned int totalTests = 0, failedTests = 0; 00487 TestDataFile(filename, totalTests, failedTests); 00488 cout << "\nTests complete. Total tests = " << totalTests << ". Failed tests = " << failedTests << ".\n"; 00489 if (failedTests != 0) 00490 cout << "SOME TESTS FAILED!\n"; 00491 return failedTests == 0; 00492 }

Generated on Fri Aug 27 13:59:53 2004 for Crypto++ by doxygen 1.3.8